diff --git a/.gitignore b/.gitignore index 9757054a50f9..d54d21b802be 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ *.iml *.iws *.pyc +*.pyo .idea/ .idea_modules/ build/*.jar @@ -62,6 +63,8 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 769defbac11b..8aca5a7f7a96 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,4 +1,5 @@ target +cache .gitignore .gitattributes .project @@ -18,6 +19,7 @@ fairscheduler.xml.template spark-defaults.conf.template log4j.properties log4j.properties.template +metrics.properties metrics.properties.template slaves slaves.template @@ -65,3 +67,5 @@ logs .*scalastyle-output.xml .*dependency-reduced-pom.xml known_translations +DESCRIPTION +NAMESPACE diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b6c6b050fa33..f10d7e277eea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,16 @@ ## Contributing to Spark -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. +*Before opening a pull request*, review the +[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +It lists steps that are required before creating a PR. In particular, consider: + +- Is the change important and ready enough to ask the community to spend time reviewing? +- Have you searched for existing, related JIRAs and pull requests? +- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is the change being proposed clearly explained and motivated? -Please see the [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. +When you contribute code, you affirm that the contribution is your original work and that you +license the work to the project under the project's open source license. Whether or not you +state this explicitly, by submitting any copyrighted material via pull request, email, or +other means you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. diff --git a/LICENSE b/LICENSE index 0a42d389e4c3..9b364a4d0007 100644 --- a/LICENSE +++ b/LICENSE @@ -771,6 +771,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java): +======================================================================== +Copyright (C) 2015 Stijn de Gouw + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. ======================================================================== For LimitedInputStream diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 000000000000..9a5889ba28b2 --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 000000000000..931d01549b26 --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 000000000000..a6970e39b55f --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 000000000000..3f889c0ca3d1 --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 000000000000..4194172a2e11 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 000000000000..008a5c668bc4 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 000000000000..55ed6f4be1a4 --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 000000000000..701adb2a3da1 --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 000000000000..1c1779a763c7 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'RDD.R' + 'pairRDD.R' + 'schema.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'backend.R' + 'broadcast.R' + 'client.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 000000000000..80283643861a --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,196 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "cartesian", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "intersection", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleByKey", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "subtract", + "subtractByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "except", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("structField", + "structField.jobj", + "structField.character", + "print.structField", + "structType", + "structType.jobj", + "structType.structField", + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 000000000000..b59b700af5dc --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column" || is.null(value)) + cols <- columns(x) + if (name %in% cols) { + if (is.null(value)) { + cols <- Filter(function(c) { c != name }, cols) + } + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + if (is.null(value)) { + return(x) + } + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' except +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the except operation. +#' @rdname except +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' exceptDF <- except(df, df2) +#' } +#' @rdname except +#' @export +setMethod("except", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + excepted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(excepted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 000000000000..1662d6bb3b1a --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1592 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- cleanClosure(func) + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(partIndex, part) { + func(partIndex, prev@func(partIndex, part)) + } + .Object@func <- cleanClosure(pipelinedFunc) + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(rdd@func, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "partitions") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(partIndex, part) { + lapply(part, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + PipelinedRDD(X, FUN) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + serializedModeRDD <- getSerializedMode(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList + }) + + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(partIndex, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(partIndex) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numPartitions, TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(partIndex, part) { + set.seed(partIndex) # partIndex as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the partitions of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + part[ord[1:num]] + } else { + part + } + } + + newRdd <- mapPartitions(x, partitionFunc) + + resList <- list() + index <- -1 + jrdd <- getJRDD(newRdd) + numPartitions <- numPartitions(newRdd) + serializedModeRDD <- getSerializedMode(newRdd) + + while (TRUE) { + index <- index + 1 + + if (index >= numPartitions) { + ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending) + resList <- resList[ord[1:num]] + break + } + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + # elems is capped to have at most `num` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = num, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(partIndex, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + partIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(partIndex, part) { + if (partIndex == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[partIndex]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) + + mergePartitions(rdd, TRUE) + }) + +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +setMethod("cartesian", + signature(x = "RDD", other = "RDD"), + function(x, other) { + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) + + mergePartitions(rdd, FALSE) + }) + +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +setMethod("subtract", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + mapFunction <- function(e) { list(e, NA) } + rdd1 <- map(x, mapFunction) + rdd2 <- map(other, mapFunction) + keys(subtractByKey(rdd1, rdd2, numPartitions)) + }) + +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +setMethod("intersection", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + rdd1 <- map(x, function(v) { list(v, NA) }) + rdd2 <- map(other, function(v) { list(v, NA) }) + + filterFunction <- function(elem) { + iters <- elem[[2]] + all(as.vector( + lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical")) + } + + keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 000000000000..4f05ba524a01 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,494 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + structField(names[[i]], types[[i]], TRUE) + }) + do.call(structType, fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlCtx A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlCtx, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + schema <- names(data) + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + structField(names[[i]], types[[i]], TRUE) + }) + schema <- do.call(structType, fields) + } + + stopifnot(class(schema) == "structType") + # schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schema$jobj, sqlCtx) + dataFrame(sdf) +} + +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#' } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlCtx, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlCtx SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' } + +jsonFile <- function(sqlCtx, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlCtx, "jsonFile", path) + dataFrame(sdf) +} + + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlCtx, rdd) +#' } + +# TODO: support schema +jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlCtx, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlCtx, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' } + +sql <- function(sqlCtx, sqlQuery) { + sdf <- callJMethod(sqlCtx, "sql", sqlQuery) + dataFrame(sdf) +} + + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlCtx, "table") +#' } + +table <- function(sqlCtx, tableName) { + sdf <- callJMethod(sqlCtx, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 000000000000..2fb6fae55f28 --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 000000000000..583fa2e7fdcf --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. +#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 000000000000..1281c41213e3 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 000000000000..95fb9ff0887b --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R schema.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" = "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 000000000000..b4845b694899 --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinPartitions <- function(sc, minPartitions) { + if (is.null(minPartitions)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minPartitions <- min(defaultParallelism, 2) + } + as.integer(minPartitions) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minPartitions = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinPartitions(sc, minPartitions)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minPartitions = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinPartitions(sc, minPartitions)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 000000000000..257b435607ce --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 000000000000..34dbe84051c5 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,573 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname cartesian +#' @export +setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname intersection +#' @export +setGeneric("intersection", function(x, other, numPartitions = 1) { + standardGeneric("intersection") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname sampleByKey +#' @export +setGeneric("sampleByKey", + function(x, withReplacement, fractions, seed) { + standardGeneric("sampleByKey") + }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", + function(x, ascending = TRUE, numPartitions = 1) { + standardGeneric("sortByKey") + }) + +#' @rdname subtract +#' @export +setGeneric("subtract", + function(x, other, numPartitions = 1) { + standardGeneric("subtract") + }) + +#' @rdname subtractByKey +#' @export +setGeneric("subtractByKey", + function(x, other, numPartitions = 1) { + standardGeneric("subtractByKey") + }) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname except +#' @export +setGeneric("except", function(x, y) { standardGeneric("except") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 000000000000..02237b3672d6 --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R schema.R column.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 000000000000..a8a25230b636 --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +#' @include generics.R +NULL + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 000000000000..9791e55791ba --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,909 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) +#' @include generics.R jobj.R RDD.R +NULL + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], + # where the key is the target partition number, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + numToInt(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + numToInt(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "numeric"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "numeric"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "numeric"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +setMethod("subtractByKey", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + filterFunction <- function(elem) { + iters <- elem[[2]] + (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) + } + + flatMapValues(filterRDD(cogroup(x, + other, + numPartitions = numPartitions), + filterFunction), + function (v) { v[[1]] }) + }) + +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +setMethod("sampleByKey", + signature(x = "RDD", withReplacement = "logical", + fractions = "vector", seed = "integer"), + function(x, withReplacement, fractions, seed) { + + for (elem in fractions) { + if (elem < 0.0) { + stop(paste("Negative fraction value ", fractions[which(fractions == elem)])) + } + } + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(partIndex, part) { + set.seed(bitwXor(seed, partIndex)) + res <- vector("list", length(part)) + len <- 0 + + # mixing because the initial seeds are close to each other + runif(10) + + for (elem in part) { + if (elem[[1]] %in% names(fractions)) { + frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) + if (withReplacement) { + count <- rpois(1, frac) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < frac) { + len <- len + 1 + res[[len]] <- elem + } + } + } else { + stop("KeyError: \"", elem[[1]], "\"") + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. (duplicated from sampleRDD) + if (len > 0) { + res[1:len] + } else { + list() + } + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R new file mode 100644 index 000000000000..e442119086b1 --- /dev/null +++ b/R/pkg/R/schema.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField +# datatypes. These are used to create and interact with DataFrame schemas. + +#' structType +#' +#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' use with createDataFrame and toDF. +#' +#' @param x a structField object (created with the field() function) +#' @param ... additional structField objects +#' @return a structType object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' schema <- structType(structField("a", "integer"), structField("b", "string")) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } +structType <- function(x, ...) { + UseMethod("structType", x) +} + +structType.jobj <- function(x) { + obj <- structure(list(), class = "structType") + obj$jobj <- x + obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } + obj +} + +structType.structField <- function(x, ...) { + fields <- list(x, ...) + if (!all(sapply(fields, inherits, "structField"))) { + stop("All arguments must be structField objects.") + } + sfObjList <- lapply(fields, function(field) { + field$jobj + }) + stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructType", + listToSeq(sfObjList)) + structType(stObj) +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + cat("StructType\n", + sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") }) + , sep = "") +} + +#' structField +#' +#' Create a structField object that contains the metadata for a single field in a schema. +#' +#' @param x The name of the field +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @return a structField object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' field1 <- structField("a", "integer", TRUE) +#' field2 <- structField("b", "string", TRUE) +#' schema <- structType(field1, field2) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } + +structField <- function(x, ...) { + UseMethod("structField", x) +} + +structField.jobj <- function(x) { + obj <- structure(list(), class = "structField") + obj$jobj <- x + obj$name <- function() { callJMethod(x, "name") } + obj$dataType <- function() { callJMethod(x, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(x, "nullable") } + obj +} + +structField.character <- function(x, type, nullable = TRUE) { + if (class(x) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + options <- c("byte", + "integer", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + dataType <- if (type %in% options) { + type + } else { + stop(paste("Unsupported type for Dataframe:", type)) + } + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructField", + x, + dataType, + nullable) + structField(sfObj) +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat("StructField(name = \"", x$name(), + "\", type = \"", x$dataType.toString(), + "\", nullable = ", x$nullable(), + ")", + sep = "") +} diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 000000000000..c53d0a961016 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + utfVal <- enc2utf8(value) + writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) + writeBin(utfVal, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 000000000000..bc82df01f0ff --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 000000000000..0e7b7bd5a5b3 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,547 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} + +# Append partition lengths to each partition in two input RDDs if needed. +# param +# x An RDD. +# Other An RDD. +# return value +# A list of two result RDDs. +appendPartitionLengths <- function(x, other) { + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # know the boundary of elements from x and other. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + len <- length(part) + part[[len + 1]] <- len + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + list (x, other) +} + +# Perform zip or cartesian between elements from two RDDs in each partition +# param +# rdd An RDD. +# zip A boolean flag indicating this call is for zip operation or not. +# return value +# A result RDD. +mergePartitions <- function(rdd, zip) { + serializerMode <- getSerializedMode(rdd) + partitionFunc <- function(partIndex, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + if (zip && lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + } else { + keys <- list() + } + if (lengthOfValues > 1) { + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + values <- list() + } + + if (!zip) { + return(mergeCompactLists(keys, values)) + } + } else { + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { list(k, v) }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(rdd, partitionFunc) +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 000000000000..80d796d46794 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 000000000000..8fe711b62208 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 000000000000..7a7f2031152a --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 000000000000..ca4218f3819f --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 000000000000..c15553ba2851 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 000000000000..fee91a427d6d --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 000000000000..e4aab37436a7 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 000000000000..8152b448d087 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 000000000000..fff028657db3 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 000000000000..d55af93e3e50 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,784 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(partIndex, part) { + part <- as.list(unlist(part) * partIndex + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(partIndex, part) { + part <- as.list(unlist(part) * partIndex) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + tempDir <- tempfile(pattern = "checkpoint") + setCheckpointDir(sc, tempDir) + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink(tempDir) +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(partIndex, part) { list(partIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("cartesian() on RDDs", { + rdd <- parallelize(sc, 1:3) + actual <- collect(cartesian(rdd, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(1, 1), list(1, 2), list(1, 3), + list(2, 1), list(2, 2), list(2, 3), + list(3, 1), list(3, 2), list(3, 3))) + + # test case where one RDD is empty + emptyRdd <- parallelize(sc, list()) + actual <- collect(cartesian(rdd, emptyRdd)) + expect_equal(actual, list()) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + actual <- collect(cartesian(rdd, rdd)) + expected <- list( + list("Spark is awesome.", "Spark is pretty."), + list("Spark is awesome.", "Spark is awesome."), + list("Spark is pretty.", "Spark is pretty."), + list("Spark is pretty.", "Spark is awesome.")) + expect_equal(sortKeyValueList(actual), expected) + + rdd1 <- parallelize(sc, 0:1) + actual <- collect(cartesian(rdd1, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(0, "Spark is pretty."), + list(0, "Spark is awesome."), + list(1, "Spark is pretty."), + list(1, "Spark is awesome."))) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(cartesian(rdd, rdd1)) + expect_equal(sortKeyValueList(actual), expected) + + unlink(fileName) +}) + +test_that("subtract() on RDDs", { + l <- list(1, 1, 2, 2, 3, 4) + rdd1 <- parallelize(sc, l) + + # subtract by itself + actual <- collect(subtract(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtract by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + l) + + rdd2 <- parallelize(sc, list(2, 4)) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + list(1, 1, 3)) + + l <- list("a", "a", "b", "b", "c", "d") + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list("b", "d")) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="character"))), + list("a", "a", "c")) +}) + +test_that("subtractByKey() on pairwise RDDs", { + l <- list(list("a", 1), list("b", 4), + list("b", 5), list("a", 2)) + rdd1 <- parallelize(sc, l) + + # subtractByKey by itself + actual <- collect(subtractByKey(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtractByKey by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(l)) + + rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list("b", 4), list("b", 5))) + + l <- list(list(1, 1), list(2, 4), + list(2, 5), list(1, 2)) + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list(2, 4), list(2, 5))) +}) + +test_that("intersection() on RDDs", { + # intersection with self + actual <- collect(intersection(rdd, rdd)) + expect_equal(sort(as.integer(actual)), nums) + + # intersection with an empty RDD + emptyRdd <- parallelize(sc, list()) + actual <- collect(intersection(rdd, emptyRdd)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) + rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) + actual <- collect(intersection(rdd1, rdd2)) + expect_equal(sort(as.integer(actual)), 1:3) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) + +test_that("sampleByKey() on pairwise RDDs", { + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) + fractions <- list(a = 0.2, b = 0.1) + sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE) + expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE) + expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE) + expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE) + expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE) + expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE) + + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) }) + fractions <- list(`2` = 0.2, `3` = 0.1) + sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE) + expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE) + expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE) + expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE) + expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) + expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 000000000000..d7dedda553c5 --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for characters", { + stringKeyRDD <- parallelize(sc, + list(list("max", 1L), list("min", 2L), + list("other", 3L), list("max", 4L)), 2L) + reduced <- combineByKey(stringKeyRDD, + function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 000000000000..af7a6c582047 --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,709 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("structType and structField", { + testField <- structField("a", "string") + expect_true(inherits(testField, "structField")) + expect_true(testField$name() == "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_true(inherits(testSchema, "structType")) + expect_true(inherits(testSchema$fields()[[2]], "structField")) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), except(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + excepted <- sortDF(except(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(excepted) == 2) + expect_true(first(excepted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 000000000000..7f4c7c315d78 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 000000000000..6b87b4b3e0b0 --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1L) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1L) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 000000000000..9c5bb427932b --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 000000000000..3584b418a71a --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 000000000000..014bf7bd7b3f --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +# Get current system time +currentTimeSecs <- function() { + as.numeric(Sys.time()) +} + +# Get elapsed time +elapsedSecs <- function() { + proc.time()[3] +} + +# Constants +specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) + +# Timing R process boot +bootTime <- currentTimeSecs() +bootElap <- elapsedSecs() + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Timing init envs for computing +initElap <- elapsedSecs() + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# Timing broadcast +broadcastElap <- elapsedSecs() + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + # Timing reading input data for execution + inputElap <- elapsedSecs() + + output <- computeFunc(partition, data) + # Timing computing + computeElap <- elapsedSecs() + + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) + } + # Timing output + outputElap <- elapsedSecs() + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + # Timing reading input data for execution + inputElap <- elapsedSecs() + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + # Timing computing + computeElap <- elapsedSecs() + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + # Timing output + outputElap <- elapsedSecs() + } +} else { + inputElap <- broadcastElap + computeElap <- broadcastElap + outputElap <- broadcastElap +} + +# Report timing +SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA) +SparkR:::writeDouble(outputCon, bootTime) +SparkR:::writeDouble(outputCon, initElap - bootElap) # init +SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast +SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input +SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute +SparkR:::writeDouble(outputCon, outputElap - computeElap) # output + +# End of output +SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 000000000000..a55a56fe80e1 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 000000000000..aa486d822837 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 000000000000..e3274b9a0c54 --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 000000000000..4f8a1ed2d83e --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 000000000000..e82ad0ba2cd0 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/README.md b/README.md index 16628bd40677..c3afc4db9c63 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at -["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-spark.html). +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). ## Interactive Scala Shell @@ -85,7 +85,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. See also ["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2..20593e710ded 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -36,19 +36,9 @@ scala-${scala.binary.version} spark-assembly-${project.version}-hadoop${hadoop.version}.jar ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} - spark - /usr/share/spark - root - 744 - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -133,20 +123,6 @@ shade - - - com.google - org.spark-project.guava - - com.google.common.** - - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - @@ -237,123 +213,6 @@ - - deb - - - - org.codehaus.mojo - buildnumber-maven-plugin - 1.2 - - - validate - - create - - - 8 - - - - - - org.vafer - jdeb - 0.11 - - - package - - jdeb - - - ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb - false - gzip - - - ${spark.jar} - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/jars - - - - ${basedir}/src/deb/RELEASE - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path} - - - - ${basedir}/../conf - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/conf - 744 - - - - ${basedir}/../bin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/bin - ${deb.bin.filemode} - - - - ${basedir}/../sbin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/sbin - 744 - - - - ${basedir}/../python - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/python - 744 - - - - - - - - - - - - kinesis-asl - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - diff --git a/assembly/src/deb/RELEASE b/assembly/src/deb/RELEASE deleted file mode 100644 index aad50ee73aa4..000000000000 --- a/assembly/src/deb/RELEASE +++ /dev/null @@ -1,2 +0,0 @@ -compute-classpath.sh uses the existence of this file to decide whether to put the assembly jar on the -classpath or instead to use classfiles in the source tree. \ No newline at end of file diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control deleted file mode 100644 index a6b4471d485f..000000000000 --- a/assembly/src/deb/control/control +++ /dev/null @@ -1,8 +0,0 @@ -Package: [[deb.pkg.name]] -Version: [[version]]-[[buildNumber]] -Section: misc -Priority: extra -Architecture: all -Maintainer: Matei Zaharia -Description: [[name]] -Distribution: development diff --git a/bagel/pom.xml b/bagel/pom.xml index 510e92640eff..1f3dec91314f 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986..edbecdae9209 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd deleted file mode 100644 index 088f993954d9..000000000000 --- a/bin/compute-classpath.cmd +++ /dev/null @@ -1,124 +0,0 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -rem script and the ExecutorRunner in standalone cluster mode. - -rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting -rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we -rem need to set it here because we use !datanucleus_jars! below. -if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion -setlocal enabledelayedexpansion -:skip_delayed_expansion - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% - -if not "x%SPARK_CONF_DIR%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% -) else ( - set CLASSPATH=%CLASSPATH%;%FWDIR%conf -) - -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) else ( - for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) - -set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% - -rem When Hive support is needed, Datanucleus jars must be included on the classpath. -rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -rem built with Hive, so look for them there. -if exist "%FWDIR%RELEASE" ( - set datanucleus_dir=%FWDIR%lib -) else ( - set datanucleus_dir=%FWDIR%lib_managed\jars -) -set "datanucleus_jars=" -for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( - set datanucleus_jars=!datanucleus_jars!;%%d -) -set CLASSPATH=%CLASSPATH%;%datanucleus_jars% - -set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes - -set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes - -if "x%SPARK_TESTING%"=="x1" ( - rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH - rem so that local compilation takes precedence over assembled jar - set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% -) - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - -rem To allow for distributions to append needed libraries to the classpath (e.g. when -rem using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and -rem append it to tbe final classpath. -if not "x%$SPARK_DIST_CLASSPATH%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_DIST_CLASSPATH% -) - -rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "%DONT_PRINT_CLASSPATH%"=="1" goto exit - -echo %CLASSPATH% - -:exit diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh deleted file mode 100755 index 9e8d0b785194..000000000000 --- a/bin/compute-classpath.sh +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -# script and the ExecutorRunner in standalone cluster mode. - -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -. "$FWDIR"/bin/load-spark-env.sh - -if [ -n "$SPARK_CLASSPATH" ]; then - CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH" -else - CLASSPATH="$SPARK_SUBMIT_CLASSPATH" -fi - -# Build up classpath -if [ -n "$SPARK_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR" -else - CLASSPATH="$CLASSPATH:$FWDIR/conf" -fi - -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" - -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -# A developer option to prepend more recently compiled Spark classes -if [ -n "$SPARK_PREPEND_CLASSES" ]; then - echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ - "classes ahead of assembly." >&2 - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" -fi - -# Use spark-assembly jar from either RELEASE or assembly directory -if [ -f "$FWDIR/RELEASE" ]; then - assembly_folder="$FWDIR"/lib -else - assembly_folder="$ASSEMBLY_DIR" -fi - -num_jars=0 - -for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark assembly in $assembly_folder" 1>&2 - echo "You need to build Spark before running this program." 1>&2 - exit 1 - fi - ASSEMBLY_JAR="$f" - num_jars=$((num_jars+1)) -done - -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2 - ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 -fi - -# Verify that versions of java used to build the jars and run Spark are compatible -jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) -if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 -fi - -CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" - -# When Hive support is needed, Datanucleus jars must be included on the classpath. -# Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -# Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -# built with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark -# assembly is built for Hive, before actually populating the CLASSPATH with the jars. -# Note that this check order is faster (by up to half a second) in the case where Hive is not used. -if [ -f "$FWDIR/RELEASE" ]; then - datanucleus_dir="$FWDIR"/lib -else - datanucleus_dir="$FWDIR"/lib_managed/jars -fi - -datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar$")" -datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" - -if [ -n "$datanucleus_jars" ]; then - hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) - if [ -n "$hive_files" ]; then - echo "Spark assembly has been built with Hive, including Datanucleus jars on classpath" 1>&2 - CLASSPATH="$CLASSPATH:$datanucleus_jars" - fi -fi - -# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 -if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" -fi - -# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! -# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -# the configurtion files. -if [ -n "$HADOOP_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" -fi -if [ -n "$YARN_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" -fi - -# To allow for distributions to append needed libraries to the classpath (e.g. when -# using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and -# append it to tbe final classpath. -if [ -n "$SPARK_DIST_CLASSPATH" ]; then - CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH" -fi - -echo "$CLASSPATH" diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 000000000000..36d932c453b6 --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. + +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 356b3d49b2ff..95779e9ddbb1 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -43,7 +44,7 @@ if [ -z "$SPARK_SCALA_VERSION" ]; then ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" - + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 @@ -54,5 +55,5 @@ if [ -z "$SPARK_SCALA_VERSION" ]; then export SPARK_SCALA_VERSION="2.11" else export SPARK_SCALA_VERSION="2.10" - fi + fi fi diff --git a/bin/pyspark b/bin/pyspark index 0b4f695dd06d..8acad6113797 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,36 +18,24 @@ # # Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" - -source "$FWDIR/bin/utils.sh" - -source "$FWDIR"/bin/load-spark-env.sh +source "$SPARK_HOME"/bin/load-spark-env.sh function usage() { + if [ -n "$1" ]; then + echo $1 + fi echo "Usage: ./bin/pyspark [options]" 1>&2 - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 } +export -f usage if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then usage fi -# Exit if the user hasn't compiled Spark -if [ ! -f "$FWDIR/RELEASE" ]; then - # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null - if [[ $? != 0 ]]; then - echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 - echo "You need to build Spark before running this program" 1>&2 - exit 1 - fi -fi - # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. # @@ -95,26 +83,13 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" -export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" - -# Build up arguments list manually to preserve quotes and backslashes. -# We export Spark submit arguments as an environment variable because shell.py must run as a -# PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" -PYSPARK_SUBMIT_ARGS="" -whitespace="[[:space:]]" -for i in "${SUBMISSION_OPTS[@]}"; do - if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi - if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi - PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" -done -export PYSPARK_SUBMIT_ARGS +export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR + export PYTHONHASHSEED=0 if [[ -n "$PYSPARK_DOC_TEST" ]]; then exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else @@ -123,14 +98,6 @@ if [[ -n "$SPARK_TESTING" ]]; then exit fi -# If a python file is provided, directly run spark-submit. -if [[ "$1" =~ \.py$ ]]; then - echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 - echo -e "Use ./bin/spark-submit \n" 1>&2 - primary="$1" - shift - gatherSparkSubmitOpts "$@" - exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" -else - exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS -fi +export PYSPARK_DRIVER_PYTHON +export PYSPARK_DRIVER_PYTHON_OPTS +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a542ec80b49d..09b4149c2a43 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -17,59 +17,21 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SCALA_VERSION=2.10 - rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% - -rem Test whether the user has built Spark -if exist "%FWDIR%RELEASE" goto skip_build_test -set FOUND_JAR=0 -for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set FOUND_JAR=1 -) -if [%FOUND_JAR%] == [0] ( - echo Failed to find Spark assembly JAR. - echo You need to build Spark before running this program. - goto exit -) -:skip_build_test +set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Figure out which Python to use. -if [%PYSPARK_PYTHON%] == [] set PYSPARK_PYTHON=python +if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( + set PYSPARK_DRIVER_PYTHON=python + if not [%PYSPARK_PYTHON%] == [] set PYSPARK_DRIVER_PYTHON=%PYSPARK_PYTHON% +) -set PYTHONPATH=%FWDIR%python;%PYTHONPATH% -set PYTHONPATH=%FWDIR%python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% -set PYTHONSTARTUP=%FWDIR%python\pyspark\shell.py -set PYSPARK_SUBMIT_ARGS=%* - -echo Running %PYSPARK_PYTHON% with PYTHONPATH=%PYTHONPATH% - -rem Check whether the argument is a file -for /f %%i in ('echo %1^| findstr /R "\.py"') do ( - set PYTHON_FILE=%%i -) - -if [%PYTHON_FILE%] == [] ( - if [%IPYTHON%] == [1] ( - ipython %IPYTHON_OPTS% - ) else ( - %PYSPARK_PYTHON% - ) -) else ( - echo. - echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0. - echo Use ./bin/spark-submit ^ - echo. - "%FWDIR%\bin\spark-submit.cmd" %PYSPARK_SUBMIT_ARGS% -) +set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -:exit +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %* diff --git a/bin/run-example b/bin/run-example index c567acf9a6b5..798e2caeb88c 100755 --- a/bin/run-example +++ b/bin/run-example @@ -42,7 +42,7 @@ fi JAR_COUNT=0 -for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do +for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -54,7 +54,7 @@ done if [ "$JAR_COUNT" -gt "1" ]; then echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 - ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2 + ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi @@ -67,7 +67,7 @@ if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -"$FWDIR"/bin/spark-submit \ +exec "$FWDIR"/bin/spark-submit \ --master $EXAMPLE_MASTER \ --class $EXAMPLE_CLASS \ "$SPARK_EXAMPLES_JAR" \ diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2..c3e0221fb62e 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class b/bin/spark-class index 1b945461fabc..c49d97ce5cf2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,88 +16,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -# NOTE: Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -cygwin=false -case "`uname`" in - CYGWIN*) cygwin=true;; -esac +set -e # Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -# Export this as SPARK_HOME -export SPARK_HOME="$FWDIR" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -. "$FWDIR"/bin/load-spark-env.sh +. "$SPARK_HOME"/bin/load-spark-env.sh if [ -z "$1" ]; then echo "Usage: spark-class []" 1>&2 exit 1 fi -if [ -n "$SPARK_MEM" ]; then - echo -e "Warning: SPARK_MEM is deprecated, please use a more specific config option" 1>&2 - echo -e "(e.g., spark.executor.memory or spark.driver.memory)." 1>&2 -fi - -# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options -DEFAULT_MEM=${SPARK_MEM:-512m} - -SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" - -# Add java opts and memory settings for master, worker, history server, executors, and repl. -case "$1" in - # Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. - 'org.apache.spark.deploy.master.Master') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.deploy.worker.Worker') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.deploy.history.HistoryServer') - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_HISTORY_OPTS" - OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} - ;; - - # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. - 'org.apache.spark.executor.CoarseGrainedExecutorBackend') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} - ;; - 'org.apache.spark.executor.MesosExecutorBackend') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" - OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} - export PYTHONPATH="$FWDIR/python:$PYTHONPATH" - export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" - ;; - - # Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + - # SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. - 'org.apache.spark.deploy.SparkSubmit') - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" - OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} - if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then - if [[ $OSTYPE == darwin* ]]; then - export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH" - else - export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH" - fi - fi - if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then - OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" - fi - ;; - - *) - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" - OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} - ;; -esac - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -109,83 +39,68 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') -# Set JAVA_OPTS to be able to load native libraries and to set heap size -if [ "$JAVA_VERSION" -ge 18 ]; then - JAVA_OPTS="$OUR_JAVA_OPTS" +# Find assembly jar +SPARK_ASSEMBLY_JAR= +if [ -f "$SPARK_HOME/RELEASE" ]; then + ASSEMBLY_DIR="$SPARK_HOME/lib" else - JAVA_OPTS="-XX:MaxPermSize=128m $OUR_JAVA_OPTS" + ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" fi -JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" -# Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" -fi - -# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! - -TOOLS_DIR="$FWDIR"/tools -SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then - # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" +num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then + echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 + echo "You need to build Spark before running this program." 1>&2 + exit 1 fi -if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then - # Use the JAR from the Maven build - # TODO: this also needs to become an assembly! - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`" +ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" +if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 fi -# Compute classpath using external script -classpath_output=$("$FWDIR"/bin/compute-classpath.sh) -if [[ "$?" != "0" ]]; then - echo "$classpath_output" - exit 1 +SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" + +# Verify that versions of java used to build the jars and run Spark are compatible +if [ -n "$JAVA_HOME" ]; then + JAR_CMD="$JAVA_HOME/bin/jar" else - CLASSPATH="$classpath_output" + JAR_CMD="jar" fi -if [[ "$1" =~ org.apache.spark.tools.* ]]; then - if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 - echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2 +if [ $(command -v "$JAR_CMD") ] ; then + jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) + if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then + echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 + echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 + echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 + echo "or build Spark with Java 6." 1>&2 exit 1 fi - CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" fi -if $cygwin; then - CLASSPATH="`cygpath -wp "$CLASSPATH"`" - if [ "$1" == "org.apache.spark.tools.JavaAPICompletenessChecker" ]; then - export SPARK_TOOLS_JAR="`cygpath -w "$SPARK_TOOLS_JAR"`" - fi +LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" + +# Add the launcher build dir to the classpath if requested. +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi -export CLASSPATH -# In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. -# Here we must parse the properties file for relevant "spark.driver.*" configs before launching -# the driver JVM itself. Instead of handling this complexity in Bash, we launch a separate JVM -# to prepare the launch environment of this driver JVM. +export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" + +# The launcher library will print arguments separated by a NULL character, to allow arguments with +# characters that would be otherwise interpreted by the shell. Read that in a while loop, populating +# an array that will be used to exec the final command. +CMD=() +while IFS= read -d '' -r ARG; do + CMD+=("$ARG") +done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") -if [ -n "$SPARK_SUBMIT_BOOTSTRAP_DRIVER" ]; then - # This is used only if the properties file actually contains these special configs - # Export the environment variables needed by SparkSubmitDriverBootstrapper - export RUNNER - export CLASSPATH - export JAVA_OPTS - export OUR_JAVA_MEM - export SPARK_CLASS=1 - shift # Ignore main class (org.apache.spark.deploy.SparkSubmit) and use our own - exec "$RUNNER" org.apache.spark.deploy.SparkSubmitDriverBootstrapper "$@" +if [ "${CMD[0]}" = "usage" ]; then + "${CMD[@]}" else - # Note: The format of this command is closely echoed in SparkSubmitDriverBootstrapper.scala - if [ -n "$SPARK_PRINT_LAUNCH_COMMAND" ]; then - echo -n "Spark Command: " 1>&2 - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" 1>&2 - echo -e "========================================\n" 1>&2 - fi - exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" + exec "${CMD[@]}" fi - diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index da46543647ef..3d068dd3a273 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -17,135 +17,51 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem Any changes to this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -setlocal enabledelayedexpansion - -set SCALA_VERSION=2.10 - rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% +set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given -if not "x%1"=="x" goto arg_given +if "x%1"=="x" ( echo Usage: spark-class ^ [^] - goto exit -:arg_given - -if not "x%SPARK_MEM%"=="x" ( - echo Warning: SPARK_MEM is deprecated, please use a more specific config option - echo e.g., spark.executor.memory or spark.driver.memory. + exit /b 1 ) -rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options -set OUR_JAVA_MEM=%SPARK_MEM% -if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m - -set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true - -rem Add java opts and memory settings for master, worker, history server, executors, and repl. -rem Master, Worker and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. -if "%1"=="org.apache.spark.deploy.master.Master" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% -) else if "%1"=="org.apache.spark.deploy.worker.Worker" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% -) else if "%1"=="org.apache.spark.deploy.history.HistoryServer" ( - set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_HISTORY_OPTS% - if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% - -rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. -) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% - if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% -) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% - if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% +rem Find assembly jar +set SPARK_ASSEMBLY_JAR=0 -rem Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS + -rem SPARK_DRIVER_MEMORY + SPARK_SUBMIT_DRIVER_MEMORY. -rem The repl also uses SPARK_REPL_OPTS. -) else if "%1"=="org.apache.spark.deploy.SparkSubmit" ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_SUBMIT_OPTS% %SPARK_REPL_OPTS% - if not "x%SPARK_SUBMIT_LIBRARY_PATH%"=="x" ( - set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_SUBMIT_LIBRARY_PATH% - ) else if not "x%SPARK_LIBRARY_PATH%"=="x" ( - set OUR_JAVA_OPTS=!OUR_JAVA_OPTS! -Djava.library.path=%SPARK_LIBRARY_PATH% - ) - if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% - if not "x%SPARK_SUBMIT_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_SUBMIT_DRIVER_MEMORY% +if exist "%SPARK_HOME%\RELEASE" ( + set ASSEMBLY_DIR=%SPARK_HOME%\lib ) else ( - set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% - if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% + set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% ) -rem Set JAVA_OPTS to be able to load native libraries and to set heap size -for /f "tokens=3" %%i in ('java -version 2^>^&1 ^| find "version"') do set jversion=%%i -for /f "tokens=1 delims=_" %%i in ("%jversion:~1,-1%") do set jversion=%%i -if "%jversion%" geq "1.8.0" ( - set JAVA_OPTS=%OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% -) else ( - set JAVA_OPTS=-XX:MaxPermSize=128m %OUR_JAVA_OPTS% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% -) -rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! - -rem Test whether the user has built Spark -if exist "%FWDIR%RELEASE" goto skip_build_test -set FOUND_JAR=0 -for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set FOUND_JAR=1 +for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( + set SPARK_ASSEMBLY_JAR=%%d ) -if "%FOUND_JAR%"=="0" ( +if "%SPARK_ASSEMBLY_JAR%"=="0" ( echo Failed to find Spark assembly JAR. echo You need to build Spark before running this program. - goto exit + exit /b 1 ) -:skip_build_test -set TOOLS_DIR=%FWDIR%tools -set SPARK_TOOLS_JAR= -for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do ( - set SPARK_TOOLS_JAR=%%d +set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% + +rem Add the launcher build dir to the classpath if requested. +if not "x%SPARK_PREPEND_CLASSES%"=="x" ( + set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% ) -rem Compute classpath using external script -set DONT_PRINT_CLASSPATH=1 -call "%FWDIR%bin\compute-classpath.cmd" -set DONT_PRINT_CLASSPATH=0 -set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR% +set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java -rem In Spark submit client mode, the driver is launched in the same JVM as Spark submit itself. -rem Here we must parse the properties file for relevant "spark.driver.*" configs before launching -rem the driver JVM itself. Instead of handling this complexity here, we launch a separate JVM -rem to prepare the launch environment of this driver JVM. - -rem In this case, leave out the main class (org.apache.spark.deploy.SparkSubmit) and use our own. -rem Leaving out the first argument is surprisingly difficult to do in Windows. Note that this must -rem be done here because the Windows "shift" command does not work in a conditional block. -set BOOTSTRAP_ARGS= -shift -:start_parse -if "%~1" == "" goto end_parse -set BOOTSTRAP_ARGS=%BOOTSTRAP_ARGS% %~1 -shift -goto start_parse -:end_parse - -if not [%SPARK_SUBMIT_BOOTSTRAP_DRIVER%] == [] ( - set SPARK_CLASS=1 - "%RUNNER%" org.apache.spark.deploy.SparkSubmitDriverBootstrapper %BOOTSTRAP_ARGS% -) else ( - "%RUNNER%" -cp "%CLASSPATH%" %JAVA_OPTS% %* +rem The launcher library prints the command to be executed in a single line suitable for being +rem executed by the batch interpreter. So read all the output of the launcher into a variable. +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( + set SPARK_CMD=%%i ) -:exit +%SPARK_CMD% diff --git a/bin/spark-shell b/bin/spark-shell index cca5aa067612..b3761b5e1375 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -28,25 +28,24 @@ esac # Enter posix mode for bash set -o posix -## Global script variables -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -function usage() { +usage() { + if [ -n "$1" ]; then + echo "$1" + fi echo "Usage: ./bin/spark-shell [options]" "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + exit "$2" } +export -f usage if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage + usage "" 0 fi -source "$FWDIR"/bin/utils.sh -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" - # SPARK-4161: scala does not assume use of the java classpath, -# so we need to add the "-Dscala.usejavacp=true" flag mnually. We +# so we need to add the "-Dscala.usejavacp=true" flag manually. We # do this specifically for the Spark shell because the scala REPL # has its own class loader, and any additional classpath specified # through spark.driver.extraClassPath is not automatically propagated. @@ -61,11 +60,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd old mode 100755 new mode 100644 diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 1d1a40da315e..02f51fe59a91 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -25,17 +25,28 @@ if %ERRORLEVEL% equ 0 ( exit /b 0 ) -call %SPARK_HOME%\bin\windows-utils.cmd %* -if %ERRORLEVEL% equ 1 ( +rem SPARK-4161: scala does not assume use of the java classpath, +rem so we need to add the "-Dscala.usejavacp=true" flag manually. We +rem do this specifically for the Spark shell because the scala REPL +rem has its own class loader, and any additional classpath specified +rem through spark.driver.extraClassPath is not automatically propagated. +if "x%SPARK_SUBMIT_OPTS%"=="x" ( + set SPARK_SUBMIT_OPTS=-Dscala.usejavacp=true + goto run_shell +) +set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" + +:run_shell +call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +set SPARK_ERROR_LEVEL=%ERRORLEVEL% +if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( call :usage exit /b 1 ) - -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %SUBMISSION_OPTS% spark-shell %APPLICATION_OPTS% - -exit /b 0 +exit /b %SPARK_ERROR_LEVEL% :usage +echo %SPARK_LAUNCHER_USAGE_ERROR% echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -%SPARK_HOME%\bin\spark-submit --help 2>&1 | findstr /V "Usage" 1>&2 -exit /b 0 +call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 +goto :eof diff --git a/bin/spark-sql b/bin/spark-sql index 3b6cc420fea8..ca1729f4cfcb 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -25,12 +25,15 @@ set -o posix # NOTE: This exact class name is matched downstream by SparkSubmit. # Any changes need to be reflected there. -CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" # Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" function usage { + if [ -n "$1" ]; then + echo "$1" + fi echo "Usage: ./bin/spark-sql [options] [cli option]" pattern="usage" pattern+="\|Spark assembly has been built with Hive" @@ -42,16 +45,13 @@ function usage { "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "CLI options:" - "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 + exit "$2" } +export -f usage if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage - exit 0 + usage "" 0 fi -source "$FWDIR"/bin/utils.sh -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" - -exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" +exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 3e5cbdbb2439..0e0afe71a0f0 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,58 +17,21 @@ # limitations under the License. # -# NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -ORIG_ARGS=("$@") - -# Set COLUMNS for progress bar -export COLUMNS=`tput cols` - -while (($#)); do - if [ "$1" = "--deploy-mode" ]; then - SPARK_SUBMIT_DEPLOY_MODE=$2 - elif [ "$1" = "--properties-file" ]; then - SPARK_SUBMIT_PROPERTIES_FILE=$2 - elif [ "$1" = "--driver-memory" ]; then - export SPARK_SUBMIT_DRIVER_MEMORY=$2 - elif [ "$1" = "--driver-library-path" ]; then - export SPARK_SUBMIT_LIBRARY_PATH=$2 - elif [ "$1" = "--driver-class-path" ]; then - export SPARK_SUBMIT_CLASSPATH=$2 - elif [ "$1" = "--driver-java-options" ]; then - export SPARK_SUBMIT_OPTS=$2 - elif [ "$1" = "--master" ]; then - export MASTER=$2 - fi - shift -done - -if [ -z "$SPARK_CONF_DIR" ]; then - export SPARK_CONF_DIR="$SPARK_HOME/conf" -fi -DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf" -if [ "$MASTER" == "yarn-cluster" ]; then - SPARK_SUBMIT_DEPLOY_MODE=cluster +SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +# disable randomized hash for string in Python 3.3+ +export PYTHONHASHSEED=0 + +# Only define a usage function if an upstream script hasn't done so. +if ! type -t usage >/dev/null 2>&1; then + usage() { + if [ -n "$1" ]; then + echo "$1" + fi + "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help + exit "$2" + } + export -f usage fi -export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"} -export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"} - -# For client mode, the driver will be launched in the same JVM that launches -# SparkSubmit, so we may need to read the properties file for any extra class -# paths, library paths, java options and memory early on. Otherwise, it will -# be too late by the time the driver JVM has started. - -if [[ "$SPARK_SUBMIT_DEPLOY_MODE" == "client" && -f "$SPARK_SUBMIT_PROPERTIES_FILE" ]]; then - # Parse the properties file only if the special configs exist - contains_special_configs=$( - grep -e "spark.driver.extra*\|spark.driver.memory" "$SPARK_SUBMIT_PROPERTIES_FILE" | \ - grep -v "^[[:space:]]*#" - ) - if [ -n "$contains_special_configs" ]; then - export SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - fi -fi - -exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" +exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 12244a9cb04f..d3fc4a5cc3f6 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -17,62 +17,22 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! - -set SPARK_HOME=%~dp0.. -set ORIG_ARGS=%* - -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client - -if not defined %SPARK_CONF_DIR% ( - set SPARK_CONF_DIR=%SPARK_HOME%\conf -) -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) else if [%1] == [--master] ( - set MASTER=%2 - ) - shift -goto loop -:continue - -if [%MASTER%] == [yarn-cluster] ( - set SPARK_SUBMIT_DEPLOY_MODE=cluster -) - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. - -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +rem disable randomized hash for string in Python 3.3+ +set PYTHONHASHSEED=0 + +set CLASS=org.apache.spark.deploy.SparkSubmit +call %~dp0spark-class2.cmd %CLASS% %* +set SPARK_ERROR_LEVEL=%ERRORLEVEL% +if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( + call :usage + exit /b 1 ) +exit /b %SPARK_ERROR_LEVEL% -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +:usage +echo %SPARK_LAUNCHER_USAGE_ERROR% +call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help +goto :eof diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 000000000000..8c918e2b09ae --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 000000000000..d7b60183ca8e --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 000000000000..e47f22c7300b --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/bin/utils.sh b/bin/utils.sh deleted file mode 100755 index 22ea2b9a6d58..000000000000 --- a/bin/utils.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Gather all spark-submit options into SUBMISSION_OPTS -function gatherSparkSubmitOpts() { - - if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then - echo "Function for printing usage of $0 is not set." 1>&2 - echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 - exit 1 - fi - - # NOTE: If you add or remove spark-sumbmit options, - # modify NOT ONLY this script but also SparkSubmitArgument.scala - SUBMISSION_OPTS=() - APPLICATION_OPTS=() - while (($#)); do - case "$1" in - --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ - --conf | --properties-file | --driver-memory | --driver-java-options | \ - --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ - --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) - if [[ $# -lt 2 ]]; then - "$SUBMIT_USAGE_FUNCTION" - exit 1; - fi - SUBMISSION_OPTS+=("$1"); shift - SUBMISSION_OPTS+=("$1"); shift - ;; - - --verbose | -v | --supervise) - SUBMISSION_OPTS+=("$1"); shift - ;; - - *) - APPLICATION_OPTS+=("$1"); shift - ;; - esac - done - - export SUBMISSION_OPTS - export APPLICATION_OPTS -} diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd deleted file mode 100644 index 1082a952dac9..000000000000 --- a/bin/windows-utils.cmd +++ /dev/null @@ -1,59 +0,0 @@ -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem Gather all spark-submit options into SUBMISSION_OPTS - -set SUBMISSION_OPTS= -set APPLICATION_OPTS= - -rem NOTE: If you add or remove spark-sumbmit options, -rem modify NOT ONLY this script but also SparkSubmitArgument.scala - -:OptsLoop -if "x%1"=="x" ( - goto :OptsLoopEnd -) - -SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--py-files\> \<--files\>" -SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>" -SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>" -SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>" -SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\>" - -echo %1 | findstr %opts% >nul -if %ERRORLEVEL% equ 0 ( - if "x%2"=="x" ( - echo "%1" requires an argument. >&2 - exit /b 1 - ) - set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 %2 - shift - shift - goto :OptsLoop -) -echo %1 | findstr "\<--verbose\> \<-v\> \<--supervise\>" >nul -if %ERRORLEVEL% equ 0 ( - set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 - shift - goto :OptsLoop -) -set APPLICATION_OPTS=%APPLICATION_OPTS% %1 -shift -goto :OptsLoop - -:OptsLoopEnd -exit /b 0 diff --git a/build/mvn b/build/mvn index 43471f83e904..3561110a4c01 100755 --- a/build/mvn +++ b/build/mvn @@ -21,6 +21,8 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has @@ -34,25 +36,25 @@ install_app() { local binary="${_DIR}/$3" # setup `curl` and `wget` silent options if we're running on Jenkins - local curl_opts="" + local curl_opts="-L" local wget_opts="" if [ -n "$AMPLAB_JENKINS" ]; then - curl_opts="-s" - wget_opts="--quiet" + curl_opts="-s ${curl_opts}" + wget_opts="--quiet ${wget_opts}" else - curl_opts="--progress-bar" - wget_opts="--progress=bar:force" + curl_opts="--progress-bar ${curl_opts}" + wget_opts="--progress=bar:force ${wget_opts}" fi if [ -z "$3" -o ! -f "$binary" ]; then # check if we already have the tarball # check if we have curl installed # download application - [ ! -f "${local_tarball}" ] && [ -n "`which curl 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ echo "exec: curl ${curl_opts} ${remote_tarball}" && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers - [ ! -f "${local_tarball}" ] && [ -n "`which wget 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ echo "exec: wget ${wget_opts} ${remote_tarball}" && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit @@ -68,10 +70,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder @@ -136,6 +138,7 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} ${ZINC_BIN} -shutdown ${ZINC_BIN} -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ @@ -143,7 +146,7 @@ if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then fi # Set any `mvn` options if not already present -export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"} +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/build/sbt b/build/sbt index 28ebb64f7197..cc3203d79bcc 100755 --- a/build/sbt +++ b/build/sbt @@ -125,4 +125,32 @@ loadConfigFile() { [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@" [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@" +exit_status=127 +saved_stty="" + +restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +saveSttySettings() { + saved_stty=$(stty -g 2>/dev/null) + if [[ ! $? ]]; then + saved_stty="" + fi +} + +saveSttySettings +trap onExit INT + run "$@" + +exit_status=$? +onExit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index f5df439effb0..504be48b358f 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -50,9 +50,9 @@ acquire_sbt_jar () { # Download printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then + if [ $(command -v curl) ]; then (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" - elif hash wget 2>/dev/null; then + elif [ $(command -v wget) ]; then (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" @@ -81,7 +81,7 @@ execRunner () { echo "" } - exec "$@" + "$@" } addJava () { diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f6..3a2a88219818 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 96b6844f0aab..2e0cb5db170a 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -87,6 +87,7 @@ # period 10 Poll period # unit seconds Units of poll period # prefix EMPTY STRING Prefix to prepend to metric name +# protocol tcp Protocol ("tcp" or "udp") to use ## Examples # Enable JmxSink for all instances by class name @@ -121,6 +122,15 @@ #worker.sink.csv.unit=minutes +# Enable Slf4jSink for all instances by class name +#*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink + +# Polling period for Slf4JSink +#*.sink.sl4j.period=1 + +#*.sink.sl4j.unit=minutes + + # Enable jvm source for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 0886b0276fb9..67f81d33361e 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -15,7 +15,7 @@ # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program # - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data -# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos +# - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08af..5e89d548cd47 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -34,6 +34,10 @@ Spark Project Core http://spark.apache.org/ + + com.google.guava + guava + com.twitter chill_${scala.binary.version} @@ -70,8 +74,17 @@ javax.servlet servlet-api + + org.codehaus.jackson + jackson-mapper-asl + + + org.apache.spark + spark-launcher_${scala.binary.version} + ${project.version} + org.apache.spark spark-network-common_${scala.binary.version} @@ -90,32 +103,52 @@ org.apache.curator curator-recipes + + org.eclipse.jetty jetty-plus + compile org.eclipse.jetty jetty-security + compile org.eclipse.jetty jetty-util + compile org.eclipse.jetty jetty-server + compile - - com.google.guava - guava + org.eclipse.jetty + jetty-http + compile + + + org.eclipse.jetty + jetty-continuation compile + + org.eclipse.jetty + jetty-servlet + compile + + + + org.eclipse.jetty.orbit + javax.servlet + ${orbit.version} + + org.apache.commons commons-lang3 @@ -204,30 +237,49 @@ stream - com.codahale.metrics + io.dropwizard.metrics metrics-core - com.codahale.metrics + io.dropwizard.metrics metrics-jvm - com.codahale.metrics + io.dropwizard.metrics metrics-json - com.codahale.metrics + io.dropwizard.metrics metrics-graphite + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + org.apache.derby derby test + + org.apache.ivy + ivy + ${ivy.version} + + + oro + + oro + ${oro.version} + org.tachyonproject tachyon-client - 0.5.0 + 0.6.4 org.apache.hadoop @@ -276,6 +328,12 @@ selenium-java test + + + xml-apis + xml-apis + test + org.mockito mockito-all @@ -286,16 +344,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - - - asm - asm - test - junit junit @@ -350,59 +398,28 @@ true - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - false - - - com.google.guava:guava - - - - - - com.google.guava:guava - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - - - - - - org.apache.maven.plugins maven-dependency-plugin + copy-dependencies package copy-dependencies - + ${project.build.directory} false false true true - guava + + guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server,jetty-security + true @@ -429,4 +446,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java new file mode 100644 index 000000000000..fbc566695905 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.scheduler.*; + +/** + * Class that allows users to receive all SparkListener events. + * Users should override the onEvent method. + * + * This is a concrete Java class in order to ensure that we don't forget to update it when adding + * new methods to SparkListener: forgetting to add a method will result in a compilation error (if + * this was a concrete Scala class, default implementations of new event handlers would be inherited + * from the SparkListener trait). + */ +public class SparkFirehoseListener implements SparkListener { + + public void onEvent(SparkListenerEvent event) { } + + @Override + public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + @Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java deleted file mode 100644 index 095f9fb94fdf..000000000000 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import java.io.Serializable; - -import scala.Function0; -import scala.Function1; -import scala.Unit; - -import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.util.TaskCompletionListener; - -/** - * Contextual information about a task which can be read or mutated during - * execution. To access the TaskContext for a running task use - * TaskContext.get(). - */ -public abstract class TaskContext implements Serializable { - /** - * Return the currently active TaskContext. This can be called inside of - * user functions to access contextual information about running tasks. - */ - public static TaskContext get() { - return taskContext.get(); - } - - private static ThreadLocal taskContext = - new ThreadLocal(); - - static void setTaskContext(TaskContext tc) { - taskContext.set(tc); - } - - static void unset() { - taskContext.remove(); - } - - /** - * Whether the task has completed. - */ - public abstract boolean isCompleted(); - - /** - * Whether the task has been killed. - */ - public abstract boolean isInterrupted(); - - /** @deprecated use {@link #isRunningLocally()} */ - @Deprecated - public abstract boolean runningLocally(); - - public abstract boolean isRunningLocally(); - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public abstract TaskContext addTaskCompletionListener(final Function1 f); - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * - * @deprecated use {@link #addTaskCompletionListener(scala.Function1)} - * - * @param f Callback function. - */ - @Deprecated - public abstract void addOnCompleteCallback(final Function0 f); - - /** - * The ID of the stage that this task belong to. - */ - public abstract int stageId(); - - /** - * The ID of the RDD partition that is computed by this task. - */ - public abstract int partitionId(); - - /** - * How many times this task has been attempted. The first task attempt will be assigned - * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. - */ - public abstract int attemptNumber(); - - /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */ - @Deprecated - public abstract long attemptId(); - - /** - * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts - * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. - */ - public abstract long taskAttemptId(); - - /** ::DeveloperApi:: */ - @DeveloperApi - public abstract TaskMetrics taskMetrics(); -} diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/core/src/main/java/org/apache/spark/api/java/function/Function0.java similarity index 78% rename from core/src/test/scala/org/apache/spark/util/FakeClock.scala rename to core/src/main/java/org/apache/spark/api/java/function/Function0.java index 0a45917b08dd..38e410c5debe 100644 --- a/core/src/test/scala/org/apache/spark/util/FakeClock.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.api.java.function; -class FakeClock extends Clock { - private var time = 0L +import java.io.Serializable; - def advance(millis: Long): Unit = time += millis - - def getTime(): Long = time +/** + * A zero-argument function that returns an R. + */ +public interface Function0 extends Serializable { + public R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 409e1a41c5d4..a90cc0e761f6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) { private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; - if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { + if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) + || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { if (runLen[n - 1] < runLen[n + 1]) n--; - mergeAt(n); - } else if (runLen[n] <= runLen[n + 1]) { - mergeAt(n); - } else { + } else if (runLen[n] > runLen[n + 1]) { break; // Invariant is established } + mergeAt(n); } } diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index c99a61f63ea2..3a2a88219818 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,8 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO -log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 14ba37d7c9bd..013db8df9b36 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -30,7 +30,7 @@ $(function() { stripeSummaryTable(); - $("input:checkbox").click(function() { + $('input[type="checkbox"]').click(function() { var column = "table ." + $(this).attr("name"); $(column).toggle(); stripeSummaryTable(); @@ -39,15 +39,15 @@ $(function() { $("#select-all-metrics").click(function() { if (this.checked) { // Toggle all un-checked options. - $('input:checkbox:not(:checked)').trigger('click'); + $('input[type="checkbox"]:not(:checked)').trigger('click'); } else { // Toggle all checked options. - $('input:checkbox:checked').trigger('click'); + $('input[type="checkbox"]:checked').trigger('click'); } }); // Trigger a click on the checkbox if a user clicks the label next to it. $("span.additional-metric-title").click(function() { - $(this).parent().find('input:checkbox').trigger('click'); + $(this).parent().find('input[type="checkbox"]').trigger('click'); }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a1f7133f897e..4910744d1d79 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -85,17 +85,13 @@ table.sortable td { filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } -span.kill-link { +a.kill-link { margin-right: 2px; margin-left: 20px; color: gray; float: right; } -span.kill-link a { - color: gray; -} - span.expand-details { font-size: 10pt; cursor: pointer; @@ -103,6 +99,12 @@ span.expand-details { float: right; } +span.rest-uri { + font-size: 10pt; + font-style: italic; + color: gray; +} + pre { font-size: 0.8em; } @@ -190,6 +192,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, +.serialization_time, .getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f31bfba3f8d..330df1d59a9b 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,11 +18,10 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} -import java.util.concurrent.atomic.AtomicLong -import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.ref.WeakReference import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -108,7 +107,7 @@ class Accumulable[R, T] ( * The typical use of this method is to directly mutate the local value, eg., to add * an element to a Set. */ - def localValue = value_ + def localValue: R = value_ /** * Set the accumulator's value; only allowed on master. @@ -136,7 +135,7 @@ class Accumulable[R, T] ( Accumulators.register(this, false) } - override def toString = if (value_ == null) "null" else value_.toString + override def toString: String = if (value_ == null) "null" else value_.toString } /** @@ -256,22 +255,22 @@ object AccumulatorParam { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // TODO: Add AccumulatorParams for other types, e.g. lists and strings @@ -279,13 +278,24 @@ object AccumulatorParam { // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right -private[spark] object Accumulators { - // TODO: Use soft references? => need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { +private[spark] object Accumulators extends Logging { + /** + * This global map holds the original accumulator objects that are created on the driver. + * It keeps weak references to these objects so that accumulators can be garbage-collected + * once the RDDs and user-code that reference them are cleaned up. + */ + val originals = Map[Long, WeakReference[Accumulable[_, _]]]() + + /** + * This thread-local map holds per-task copies of accumulators; it is used to collect the set + * of accumulator updates to send back to the driver when tasks complete. After tasks complete, + * this map is cleared by `Accumulators.clear()` (see Executor.scala). + */ + private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { override protected def initialValue() = Map[Long, Accumulable[_, _]]() } - var lastId: Long = 0 + + private var lastId: Long = 0 def newId(): Long = synchronized { lastId += 1 @@ -294,7 +304,7 @@ private[spark] object Accumulators { def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { if (original) { - originals(a.id) = a + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } else { localAccums.get()(a.id) = a } @@ -303,7 +313,13 @@ private[spark] object Accumulators { // Clear the local (non-original) accumulators for the current thread def clear() { synchronized { - localAccums.get.clear + localAccums.get.clear() + } + } + + def remove(accId: Long) { + synchronized { + originals.remove(accId) } } @@ -320,11 +336,20 @@ private[spark] object Accumulators { def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value + // Since we are now storing weak references, we must check whether the underlying data + // is valid. + originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value + case None => + throw new IllegalAccessError("Attempted to access garbage collected Accumulator.") + } + } else { + logWarning(s"Ignoring accumulator update for unknown accumulator id $id") } } } - def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) - def stringifyValue(value: Any) = "%s".format(value) + def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) + + def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a0c0372b7f0e..4d20c7369376 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,13 +44,17 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics - .getInputMetricsForReadMethod(inputMetrics.readMethod) - existingMetrics.addBytesRead(inputMetrics.bytesRead) - - new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) - + .getInputMetricsForReadMethod(blockResult.readMethod) + existingMetrics.incBytesRead(blockResult.bytes) + + val iter = blockResult.data.asInstanceOf[Iterator[T]] + new InterruptibleIterator[T](context, iter) { + override def next(): T = { + existingMetrics.incRecordsRead(1) + delegate.next() + } + } case None => // Acquire a lock for loading this partition // If another thread already holds the lock, wait for it to finish return its results diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index ede1e23f4fcc..37198d887b07 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.util.Utils /** @@ -32,6 +32,8 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanAccum(accId: Long) extends CleanupTask +private case class CleanCheckpoint(rddId: Int) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -93,68 +95,95 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { + def attachListener(listener: CleanerListener): Unit = { listeners += listener } /** Start the cleaner. */ - def start() { + def start(): Unit = { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() } - /** Stop the cleaner. */ - def stop() { + /** + * Stop the cleaning thread and wait until the thread has finished running its current task. + */ + def stop(): Unit = { stopped = true + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkContext, + // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). + synchronized { + cleaningThread.interrupt() + } + cleaningThread.join() } /** Register a RDD for cleanup when it is garbage collected. */ - def registerRDDForCleanup(rdd: RDD[_]) { + def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } + def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + registerForCleanup(a, CleanAccum(a.id)) + } + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } /** Register a Broadcast for cleanup when it is garbage collected. */ - def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = { registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } + /** Register a RDDCheckpointData for cleanup when it is garbage collected. */ + def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = { + registerForCleanup(rdd, CleanCheckpoint(parentId)) + } + /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } /** Keep cleaning RDD, shuffle, and broadcast state. */ - private def keepCleaning(): Unit = Utils.logUncaughtExceptions { + private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { while (!stopped) { try { val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) .map(_.asInstanceOf[CleanupTaskWeakReference]) - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get - task match { - case CleanRDD(rddId) => - doCleanupRDD(rddId, blocking = blockOnCleanupTasks) - case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) - case CleanBroadcast(broadcastId) => - doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) + } } } } catch { + case ie: InterruptedException if stopped => // ignore case e: Exception => logError("Error in cleaning thread", e) } } } /** Perform RDD cleanup. */ - def doCleanupRDD(rddId: Int, blocking: Boolean) { + def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) @@ -166,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform shuffle cleanup, asynchronously. */ - def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -179,17 +208,42 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = { try { - logDebug("Cleaning broadcast " + broadcastId) + logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) listeners.foreach(_.broadcastCleaned(broadcastId)) - logInfo("Cleaned broadcast " + broadcastId) + logDebug(s"Cleaned broadcast $broadcastId") } catch { case e: Exception => logError("Error cleaning broadcast " + broadcastId, e) } } + /** Perform accumulator cleanup. */ + def doCleanupAccum(accId: Long, blocking: Boolean): Unit = { + try { + logDebug("Cleaning accumulator " + accId) + Accumulators.remove(accId) + listeners.foreach(_.accumCleaned(accId)) + logInfo("Cleaned accumulator " + accId) + } catch { + case e: Exception => logError("Error cleaning accumulator " + accId, e) + } + } + + /** Perform checkpoint cleanup. */ + def doCleanCheckpoint(rddId: Int): Unit = { + try { + logDebug("Cleaning rdd checkpoint data " + rddId) + RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + listeners.foreach(_.checkpointCleaned(rddId)) + logInfo("Cleaned rdd checkpoint data " + rddId) + } + catch { + case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] @@ -206,4 +260,6 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int) def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) + def accumCleaned(accId: Long) + def checkpointCleaned(rddId: Long) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9a7cd4523e5a..fc8cdde9348e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -74,7 +74,7 @@ class ShuffleDependency[K, V, C]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { - override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] + override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] val shuffleId: Int = _rdd.context.newShuffleId() @@ -91,7 +91,7 @@ class ShuffleDependency[K, V, C]( */ @DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) + override def getParents(partitionId: Int): List[Int] = List(partitionId) } @@ -107,7 +107,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) } else { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index a46a81eabd96..443830f8d03b 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -19,24 +19,32 @@ package org.apache.spark /** * A client that communicates with the cluster manager to request or kill executors. + * This is currently supported only in YARN mode. */ private[spark] trait ExecutorAllocationClient { + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def requestExecutors(numAdditionalExecutors: Int): Boolean /** * Request that the cluster manager kill the specified executors. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutors(executorIds: Seq[String]): Boolean /** * Request that the cluster manager kill the specified executor. - * Return whether the request is acknowledged by the cluster manager. + * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index b28da192c1c0..b986fa87dc2f 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,9 +17,12 @@ package org.apache.spark +import java.util.concurrent.TimeUnit + import scala.collection.mutable import org.apache.spark.scheduler._ +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -49,6 +52,7 @@ import org.apache.spark.scheduler._ * spark.dynamicAllocation.enabled - Whether this feature is enabled * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors + * spark.dynamicAllocation.initialExecutors - Number of executors to start with * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors @@ -70,21 +74,22 @@ private[spark] class ExecutorAllocationManager( import ExecutorAllocationManager._ - // Lower and upper bounds on the number of executors. These are required. - private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) - private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) + // Lower and upper bounds on the number of executors. + private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", + Integer.MAX_VALUE) - // How long there must be backlogged tasks for before an addition is triggered - private val schedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + // How long there must be backlogged tasks for before an addition is triggered (seconds) + private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") - // Same as above, but used only after `schedulerBacklogTimeout` is exceeded - private val sustainedSchedulerBacklogTimeout = conf.getLong( - "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded + private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", s"${schedulerBacklogTimeoutS}s") - // How long an executor must be idle for before it is removed - private val executorIdleTimeout = conf.getLong( - "spark.dynamicAllocation.executorIdleTimeout", 600) + // How long an executor must be idle for before it is removed (seconds) + private val executorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.executorIdleTimeout", "600s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -121,34 +126,38 @@ private[spark] class ExecutorAllocationManager( private val intervalMillis: Long = 100 // Clock used to schedule when executors should be added and removed - private var clock: Clock = new RealClock + private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-dynamic-executor-allocation") + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. */ private def validateSettings(): Unit = { if (minNumExecutors < 0 || maxNumExecutors < 0) { - throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") + throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be positive!") } - if (minNumExecutors == 0 || maxNumExecutors == 0) { - throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!") + if (maxNumExecutors == 0) { + throw new SparkException("spark.dynamicAllocation.maxExecutors cannot be 0!") } if (minNumExecutors > maxNumExecutors) { throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") } - if (schedulerBacklogTimeout <= 0) { + if (schedulerBacklogTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") } - if (sustainedSchedulerBacklogTimeout <= 0) { + if (sustainedSchedulerBacklogTimeoutS <= 0) { throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeout <= 0) { + if (executorIdleTimeoutS <= 0) { throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") } // Require external shuffle service for dynamic allocation @@ -170,47 +179,55 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. */ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + } + + /** + * The number of executors we would have if the cluster manager were to fulfill all our existing + * requests. + */ + private def targetNumExecutors(): Int = + numExecutorsPending + executorIds.size - executorsPendingToRemove.size + + /** + * The maximum number of executors we would need under the current load to satisfy all running + * and pending tasks, rounded up. + */ + private def maxNumExecutorsNeeded(): Int = { + val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks + (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor } /** - * If the add time has expired, request new executors and refresh the add time. - * If the remove time for an existing executor has expired, kill the executor. + * This is called at a fixed interval to regulate the number of pending executor requests + * and number of executors running. + * + * First, adjust our requested executors based on the add time and our current needs. + * Then, if the remove time for an existing executor has expired, kill the executor. + * * This is factored out into its own method for testing. */ private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - if (addTime != NOT_SET && now >= addTime) { - addExecutors() - logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeout seconds)") - addTime += sustainedSchedulerBacklogTimeout * 1000 - } + + addOrCancelExecutorRequests(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime @@ -221,59 +238,89 @@ private[spark] class ExecutorAllocationManager( } } + /** + * Check to see whether our existing allocation and the requests we've made previously exceed our + * current needs. If so, let the cluster manager know so that it can cancel pending requests that + * are unneeded. + * + * If not, and the add time has expired, see if we can request new executors and refresh the add + * time. + * + * @return the delta in the target number of executors. + */ + private def addOrCancelExecutorRequests(now: Long): Int = synchronized { + val currentTarget = targetNumExecutors + val maxNeeded = maxNumExecutorsNeeded + + if (maxNeeded < currentTarget) { + // The target number exceeds the number we actually need, so stop adding new + // executors and inform the cluster manager to cancel the extra pending requests. + val newTotalExecutors = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(newTotalExecutors) + numExecutorsToAdd = 1 + updateNumExecutorsPending(newTotalExecutors) + } else if (addTime != NOT_SET && now >= addTime) { + val delta = addExecutors(maxNeeded) + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") + addTime += sustainedSchedulerBacklogTimeoutS * 1000 + delta + } else { + 0 + } + } + /** * Request a number of executors from the cluster manager. * If the cap on the number of executors is reached, give up and reset the * number of executors to add next round instead of continuing to double it. - * Return the number actually requested. + * + * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending + * tasks could fill + * @return the number of additional executors actually requested. */ - private def addExecutors(): Int = synchronized { - // Do not request more executors if we have already reached the upper bound - val numExistingExecutors = executorIds.size + numExecutorsPending - if (numExistingExecutors >= maxNumExecutors) { + private def addExecutors(maxNumExecutorsNeeded: Int): Int = { + // Do not request more executors if it would put our target over the upper bound + val currentTarget = targetNumExecutors + if (currentTarget >= maxNumExecutors) { logDebug(s"Not adding executors because there are already ${executorIds.size} " + s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - // The number of executors needed to satisfy all pending tasks is the number of tasks pending - // divided by the number of tasks each executor can fit, rounded up. - val maxNumExecutorsPending = - (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor - if (numExecutorsPending >= maxNumExecutorsPending) { - logDebug(s"Not adding executors because there are already $numExecutorsPending " + - s"pending and pending tasks could only fill $maxNumExecutorsPending") - numExecutorsToAdd = 1 - return 0 - } - - // It's never useful to request more executors than could satisfy all the pending tasks, so - // cap request at that amount. - // Also cap request with respect to the configured upper bound. - val maxNumExecutorsToAdd = math.min( - maxNumExecutorsPending - numExecutorsPending, - maxNumExecutors - numExistingExecutors) - assert(maxNumExecutorsToAdd > 0) - - val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) - - val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd - val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd) + val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded) + val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors) + val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors) if (addRequestAcknowledged) { - logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + - s"tasks are backlogged (new desired total will be $newTotalExecutors)") - numExecutorsToAdd = - if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 - numExecutorsPending += actualNumExecutorsToAdd - actualNumExecutorsToAdd + val delta = updateNumExecutorsPending(newTotalExecutors) + logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + s" (new desired total will be $newTotalExecutors)") + numExecutorsToAdd = if (delta == numExecutorsToAdd) { + numExecutorsToAdd * 2 + } else { + 1 + } + delta } else { - logWarning(s"Unable to reach the cluster manager " + - s"to request $actualNumExecutorsToAdd executors!") + logWarning( + s"Unable to reach the cluster manager to request $newTotalExecutors total executors!") 0 } } + /** + * Given the new target number of executors, update the number of pending executor requests, + * and return the delta from the old number of pending requests. + */ + private def updateNumExecutorsPending(newTotalExecutors: Int): Int = { + val newNumExecutorsPending = + newTotalExecutors - executorIds.size + executorsPendingToRemove.size + val delta = newNumExecutorsPending - numExecutorsPending + numExecutorsPending = newNumExecutorsPending + delta + } + /** * Request the cluster manager to remove the given executor. * Return whether the request is received. @@ -304,7 +351,7 @@ private[spark] class ExecutorAllocationManager( val removeRequestAcknowledged = testing || client.killExecutor(executorId) if (removeRequestAcknowledged) { logInfo(s"Removing executor $executorId because it has been idle for " + - s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") + s"$executorIdleTimeoutS seconds (new desired total will be ${numExistingExecutors - 1})") executorsPendingToRemove.add(executorId) true } else { @@ -360,8 +407,8 @@ private[spark] class ExecutorAllocationManager( private def onSchedulerBacklogged(): Unit = synchronized { if (addTime == NOT_SET) { logDebug(s"Starting timer to add executors because pending tasks " + - s"are building up (to expire in $schedulerBacklogTimeout seconds)") - addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + s"are building up (to expire in $schedulerBacklogTimeoutS seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeoutS * 1000 } } @@ -384,8 +431,8 @@ private[spark] class ExecutorAllocationManager( if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") @@ -413,6 +460,8 @@ private[spark] class ExecutorAllocationManager( private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + // Number of tasks currently running on the cluster. Should be 0 when no stages are active. + private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { val stageId = stageSubmitted.stageInfo.stageId @@ -433,6 +482,10 @@ private[spark] class ExecutorAllocationManager( // This is needed in case the stage is aborted for any reason if (stageIdToNumTasks.isEmpty) { allocationManager.onSchedulerQueueEmpty() + if (numRunningTasks != 0) { + logWarning("No stages are running, but numRunningTasks != 0") + numRunningTasks = 0 + } } } } @@ -444,6 +497,7 @@ private[spark] class ExecutorAllocationManager( val executorId = taskStart.taskInfo.executorId allocationManager.synchronized { + numRunningTasks += 1 // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is // possible because these events are posted in different threads. (see SPARK-4951) @@ -473,7 +527,8 @@ private[spark] class ExecutorAllocationManager( val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId allocationManager.synchronized { - // If the executor is no longer running scheduled any tasks, mark it as idle + numRunningTasks -= 1 + // If the executor is no longer running any scheduled tasks, mark it as idle if (executorIdToTaskIds.contains(executorId)) { executorIdToTaskIds(executorId) -= taskId if (executorIdToTaskIds(executorId).isEmpty) { @@ -484,8 +539,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { - val executorId = blockManagerAdded.blockManagerId.executorId + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + val executorId = executorAdded.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { // This guards against the race condition in which the `SparkListenerTaskStart` // event is posted before the `SparkListenerBlockManagerAdded` event, which is @@ -496,9 +551,8 @@ private[spark] class ExecutorAllocationManager( } } - override def onBlockManagerRemoved( - blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { - allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + allocationManager.onExecutorRemoved(executorRemoved.executorId) } /** @@ -513,6 +567,11 @@ private[spark] class ExecutorAllocationManager( }.sum } + /** + * The number of tasks currently running across all stages. + */ + def totalRunningTasks(): Int = numRunningTasks + /** * Return true if an executor is not currently running a task, and false otherwise. * @@ -528,28 +587,3 @@ private[spark] class ExecutorAllocationManager( private object ExecutorAllocationManager { val NOT_SET = Long.MaxValue } - -/** - * An abstract clock for measuring elapsed time. - */ -private trait Clock { - def getTimeMillis: Long -} - -/** - * A clock backed by a monotonically increasing time source. - * The time returned by this clock does not correspond to any notion of wall-clock time. - */ -private class RealClock extends Clock { - override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) -} - -/** - * A clock that allows the caller to customize the time. - * This is used mainly for testing. - */ -private class TestClock(startTimeMillis: Long) extends Clock { - private var time: Long = startTimeMillis - override def getTimeMillis: Long = time - def tick(ms: Long): Unit = { time += ms } -} diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e97a7375a267..91f9ef8ce718 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -168,7 +168,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - def jobIds = Seq(jobWaiter.jobId) + def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } @@ -276,7 +276,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds = jobs + def jobIds: Seq[Int] = jobs } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 83ae57b7f151..68d05d5b0253 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,33 +17,131 @@ package org.apache.spark -import akka.actor.Actor +import java.util.concurrent.{ScheduledFuture, TimeUnit} + +import scala.collection.mutable + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} +import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. + * components to convey liveness or execution information for in-progress tasks. It will also + * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( executorId: String, taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. + */ +private[spark] case object TaskSchedulerIsSet + +private[spark] case object ExpireDeadHosts + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { - - override def receiveWithLogging = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val response = HeartbeatResponse( - !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) - sender ! response +private[spark] class HeartbeatReceiver(sc: SparkContext) + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv + + private[spark] var scheduler: TaskScheduler = null + + // executor ID -> timestamp of when the last heartbeat from this executor was received + private val executorLastSeen = new mutable.HashMap[String, Long] + + // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses + // "milliseconds" + private val slaveTimeoutMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") + private val executorTimeoutMs = + sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + + // "spark.network.timeoutInterval" uses "seconds", while + // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" + private val timeoutIntervalMs = + sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") + private val checkTimeoutIntervalMs = + sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 + + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + + private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) + } + + override def receive: PartialFunction[Any, Unit] = { + case ExpireDeadHosts => + expireDeadHosts() + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + executorLastSeen(executorId) = System.currentTimeMillis() + context.reply(response) + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } + + private def expireDeadHosts(): Unit = { + logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") + val now = System.currentTimeMillis() + for ((executorId, lastSeenMs) <- executorLastSeen) { + if (now - lastSeenMs > executorTimeoutMs) { + logWarning(s"Removing executor $executorId with no recent heartbeats: " + + s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms") + scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + + s"timed out after ${now - lastSeenMs} ms")) + if (sc.supportDynamicAllocation) { + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) + } + executorLastSeen.remove(executorId) + } + } + } + + override def onStop(): Unit = { + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel(true) + } + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() + } +} + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" } diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 677c5e0f89d7..7e706bcc42f0 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -36,7 +36,7 @@ private[spark] class HttpFileServer( var serverUri : String = null def initialize() { - baseDir = Utils.createTempDir() + baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd") fileDir = new File(baseDir, "files") jarDir = new File(baseDir, "jars") fileDir.mkdir() @@ -50,6 +50,15 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs + try { + Utils.deleteRecursively(baseDir) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e) + } } def addFile(file: File) : String = { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index fa22787ce7ea..8de3a6c04df3 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File +import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} @@ -72,7 +73,10 @@ private[spark] class HttpServer( */ private def doStart(startPort: Int): (Server, Int) = { val server = new Server() - val connector = new SocketConnector + + val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory() + .map(new SslSocketConnector(_)).getOrElse(new SocketConnector) + connector.setMaxIdleTime(60 * 1000) connector.setSoLingerTime(-1) connector.setPort(startPort) @@ -149,13 +153,14 @@ private[spark] class HttpServer( } /** - * Get the URI of this HTTP server (http://host:port) + * Get the URI of this HTTP server (http://host:port or https://host:port) */ def uri: String = { if (server == null) { throw new ServerStateException("Server is not started") } else { - "http://" + Utils.localIpAddress + ":" + port + val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" + s"$scheme://${Utils.localHostNameForURI()}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e3..419d093d5564 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7..d65c94e41066 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,11 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithReply[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) @@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -350,17 +345,22 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e53a78ead2c0..b8d244408bc5 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,7 +76,7 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { case null => 0 @@ -154,7 +154,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = rangeBounds.length + 1 + def numPartitions: Int = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala new file mode 100644 index 000000000000..2cdc167f85af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} +import org.eclipse.jetty.util.ssl.SslContextFactory + +/** + * SSLOptions class is a common container for SSL configuration options. It offers methods to + * generate specific objects to configure SSL for different communication protocols. + * + * SSLOptions is intended to provide the maximum common set of SSL settings, which are supported + * by the protocol, which it can generate the configuration for. Since Akka doesn't support client + * authentication with SSL, SSLOptions cannot support it either. + * + * @param enabled enables or disables SSL; if it is set to false, the rest of the + * settings are disregarded + * @param keyStore a path to the key-store file + * @param keyStorePassword a password to access the key-store file + * @param keyPassword a password to access the private key in the key-store + * @param trustStore a path to the trust-store file + * @param trustStorePassword a password to access the trust-store file + * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java + * @param enabledAlgorithms a set of encryption algorithms to use + */ +private[spark] case class SSLOptions( + enabled: Boolean = false, + keyStore: Option[File] = None, + keyStorePassword: Option[String] = None, + keyPassword: Option[String] = None, + trustStore: Option[File] = None, + trustStorePassword: Option[String] = None, + protocol: Option[String] = None, + enabledAlgorithms: Set[String] = Set.empty) { + + /** + * Creates a Jetty SSL context factory according to the SSL settings represented by this object. + */ + def createJettySslContextFactory(): Option[SslContextFactory] = { + if (enabled) { + val sslContextFactory = new SslContextFactory() + + keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath)) + trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath)) + keyStorePassword.foreach(sslContextFactory.setKeyStorePassword) + trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) + keyPassword.foreach(sslContextFactory.setKeyManagerPassword) + protocol.foreach(sslContextFactory.setProtocol) + sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + + Some(sslContextFactory) + } else { + None + } + } + + /** + * Creates an Akka configuration object which contains all the SSL settings represented by this + * object. It can be used then to compose the ultimate Akka configuration. + */ + def createAkkaConfig: Option[Config] = { + import scala.collection.JavaConversions._ + if (enabled) { + Some(ConfigFactory.empty() + .withValue("akka.remote.netty.tcp.security.key-store", + ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.key-store-password", + ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.trust-store", + ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.trust-store-password", + ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.key-password", + ConfigValueFactory.fromAnyRef(keyPassword.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.random-number-generator", + ConfigValueFactory.fromAnyRef("")) + .withValue("akka.remote.netty.tcp.security.protocol", + ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) + .withValue("akka.remote.netty.tcp.security.enabled-algorithms", + ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + .withValue("akka.remote.netty.tcp.enable-ssl", + ConfigValueFactory.fromAnyRef(true))) + } else { + None + } + } + + /** Returns a string representation of this SSLOptions with all the passwords masked. */ + override def toString: String = s"SSLOptions{enabled=$enabled, " + + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" + +} + +private[spark] object SSLOptions extends Logging { + + /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace. + * + * The following settings are allowed: + * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory + * $ - `[ns].keyStorePassword` - a password to the key-store file + * $ - `[ns].keyPassword` - a password to the private key + * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current + * directory + * $ - `[ns].trustStorePassword` - a password to the trust-store file + * $ - `[ns].protocol` - a protocol name supported by a particular Java version + * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers + * + * For a list of protocols and ciphers supported by particular Java versions, you may go to + * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle + * blog page]]. + * + * You can optionally specify the default configuration. If you do, for each setting which is + * missing in SparkConf, the corresponding setting is used from the default configuration. + * + * @param conf Spark configuration object where the settings are collected from + * @param ns the namespace name + * @param defaults the default configuration + * @return [[org.apache.spark.SSLOptions]] object + */ + def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) + .orElse(defaults.flatMap(_.keyStore)) + + val keyStorePassword = conf.getOption(s"$ns.keyStorePassword") + .orElse(defaults.flatMap(_.keyStorePassword)) + + val keyPassword = conf.getOption(s"$ns.keyPassword") + .orElse(defaults.flatMap(_.keyPassword)) + + val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_)) + .orElse(defaults.flatMap(_.trustStore)) + + val trustStorePassword = conf.getOption(s"$ns.trustStorePassword") + .orElse(defaults.flatMap(_.trustStorePassword)) + + val protocol = conf.getOption(s"$ns.protocol") + .orElse(defaults.flatMap(_.protocol)) + + val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms") + .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet) + .orElse(defaults.map(_.enabledAlgorithms)) + .getOrElse(Set.empty) + + new SSLOptions( + enabled, + keyStore, + keyStorePassword, + keyPassword, + trustStore, + trustStorePassword, + protocol, + enabledAlgorithms) + } + +} + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index ec82d09cd079..3653f724ba19 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -18,11 +18,16 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} +import java.security.KeyStore +import java.security.cert.X509Certificate +import javax.net.ssl._ +import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.network.sasl.SecretKeyHolder +import org.apache.spark.util.Utils /** * Spark class responsible for security. @@ -55,7 +60,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators * who always have permission to view or modify the Spark application. * - * Spark does not currently support encryption after authentication. + * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. * * At this point spark has multiple communication protocols that need to be secured and * different underlying mechanisms are used depending on the protocol: @@ -67,8 +72,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder * to connect to the server. There is no control of the underlying * authentication mechanism so its not clear if the password is passed in * plaintext or uses DIGEST-MD5 or some other mechanism. - * Akka also has an option to turn on SSL, this option is not currently supported - * but we could add a configuration option in the future. + * + * Akka also has an option to turn on SSL, this option is currently supported (see + * the details below). * * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty * for the HttpServer. Jetty supports multiple authentication mechanisms - @@ -77,8 +83,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder * to authenticate using DIGEST-MD5 via a single user and the shared secret. * Since we are using DIGEST-MD5, the shared secret is not passed on the wire * in plaintext. - * We currently do not support SSL (https), but Jetty can be configured to use it - * so we could add a configuration option for this in the future. + * + * We currently support SSL (https) for this communication protocol (see the details + * below). * * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. * Any clients must specify the user and password. There is a default @@ -142,9 +149,40 @@ import org.apache.spark.network.sasl.SecretKeyHolder * authentication. Spark will then use that user to compare against the view acls to do * authorization. If not filter is in place the user is generally null and no authorization * can take place. + * + * Connection encryption (SSL) configuration is organized hierarchically. The user can configure + * the default SSL settings which will be used for all the supported communication protocols unless + * they are overwritten by protocol specific settings. This way the user can easily provide the + * common settings for all the protocols without disabling the ability to configure each one + * individually. + * + * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, + * denote the global configuration for all the supported protocols. In order to override the global + * configuration for the particular protocol, the properties must be overwritten in the + * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global + * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be either `akka` for + * Akka based connections or `fs` for broadcast and file server. + * + * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of + * options that can be specified. + * + * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions + * object parses Spark configuration at a given namespace and builds the common representation + * of SSL settings. SSLOptions is then used to provide protocol-specific configuration like + * TypeSafe configuration for Akka or SSLContextFactory for Jetty. + * + * SSL must be configured on each node and configured for each component involved in + * communication using the particular protocol. In YARN clusters, the key-store can be prepared on + * the client side then distributed and used by the executors as the part of the application + * (YARN allows the user to deploy files before the application is started). + * In standalone deployment, the user needs to provide key-stores and configuration + * options for master and workers. In this mode, the user may allow the executors to use the SSL + * settings inherited from the worker which spawned that executor. It can be accomplished by + * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { +private[spark] class SecurityManager(sparkConf: SparkConf) + extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -166,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) + Utils.getCurrentUserName()) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) @@ -196,6 +234,57 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with ) } + // the default SSL configuration - it will be used by all communication layers unless overwritten + private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + + // SSL configuration for different communication layers - they can override the default + // configuration at a specified namespace. The namespace *must* start with spark.ssl. + val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions)) + val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions)) + + logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions") + logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions") + + val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { + val trustStoreManagers = + for (trustStore <- fileServerSSLOptions.trustStore) yield { + val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() + + try { + val ks = KeyStore.getInstance(KeyStore.getDefaultType) + ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) + + val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + tmf.init(ks) + tmf.getTrustManagers + } finally { + input.close() + } + } + + lazy val credulousTrustStoreManagers = Array({ + logWarning("Using 'accept-all' trust manager for SSL connections.") + new X509TrustManager { + override def getAcceptedIssuers: Array[X509Certificate] = null + + override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} + + override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} + }: TrustManager + }) + + val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default")) + sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) + + val hostVerifier = new HostnameVerifier { + override def verify(s: String, sslSession: SSLSession): Boolean = true + } + + (Some(sslContext.getSocketFactory), Some(hostVerifier)) + } else { + (None, None) + } + /** * Split a comma separated String, filter out any empty items, and return a Set of strings */ diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index 55cb25946c2a..cb2cae185256 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -28,8 +28,10 @@ import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString + + def value: T = t + + override def toString: String = t.toString private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f9d4aa4240e9..c1996e08756a 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,10 +17,14 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.JavaConverters._ -import scala.collection.concurrent.TrieMap -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.Utils /** * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. @@ -47,12 +51,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new TrieMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties - for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + set(key, value) } } @@ -64,7 +68,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + logDeprecationWarning(key) + settings.put(key, value) this } @@ -129,15 +134,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Set multiple parameters together */ - def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + def setAll(settings: Traversable[(String, String)]): SparkConf = { + settings.foreach { case (k, v) => set(k, v) } this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value + if (settings.putIfAbsent(key, value) == null) { + logDeprecationWarning(key) } this } @@ -164,21 +169,58 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) + } + + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then seconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsSeconds(key: String): Long = { + Utils.timeStringAsSeconds(get(key)) } + /** + * Get a time parameter as seconds, falling back to a default if not set. If no + * suffix is provided then seconds are assumed. + */ + def getTimeAsSeconds(key: String, defaultValue: String): Long = { + Utils.timeStringAsSeconds(get(key, defaultValue)) + } + + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. + * @throws NoSuchElementException + */ + def getTimeAsMs(key: String): Long = { + Utils.timeStringAsMs(get(key)) + } + + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. + */ + def getTimeAsMs(key: String, defaultValue: String): Long = { + Utils.timeStringAsMs(get(key, defaultValue)) + } + + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -225,11 +267,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(key) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -241,7 +283,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." logWarning(msg) @@ -266,7 +308,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -282,7 +324,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate memory fractions val memoryKeys = Seq( "spark.storage.memoryFraction", - "spark.shuffle.memoryFraction", + "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") @@ -346,11 +388,72 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } + } -private[spark] object SparkConf { +private[spark] object SparkConf extends Logging { + + /** + * Maps deprecated config keys to information about the deprecation. + * + * The extra information is logged as a warning when the config is present in the user's + * configuration. + */ + private val deprecatedConfigs: Map[String, DeprecatedConfig] = { + val configs = Seq( + DeprecatedConfig("spark.cache.class", "0.8", + "The spark.cache.class property is no longer being used! Specify storage levels using " + + "the RDD.persist() method instead."), + DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", + "Please use spark.{driver,executor}.userClassPathFirst instead.")) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + } + + /** + * Maps a current config key to alternate keys that were used in previous version of Spark. + * + * The alternates are used in the order defined in this map. If deprecated configs are + * present in the user's configuration, a warning is logged. + */ + private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( + "spark.executor.userClassPathFirst" -> Seq( + AlternateConfig("spark.files.userClassPathFirst", "1.3")), + "spark.history.fs.update.interval" -> Seq( + AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), + AlternateConfig("spark.history.fs.updateInterval", "1.3"), + AlternateConfig("spark.history.updateInterval", "1.3")), + "spark.history.fs.cleaner.interval" -> Seq( + AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), + "spark.history.fs.cleaner.maxAge" -> Seq( + AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), + "spark.yarn.am.waitTime" -> Seq( + AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", + // Translate old value to a duration, with 10s wait time per try. + translation = s => s"${s.toLong * 10}s")), + "spark.rpc.numRetries" -> Seq( + AlternateConfig("spark.akka.num.retries", "1.4")), + "spark.rpc.retry.wait" -> Seq( + AlternateConfig("spark.akka.retry.wait", "1.4")), + "spark.rpc.askTimeout" -> Seq( + AlternateConfig("spark.akka.askTimeout", "1.4")), + "spark.rpc.lookupTimeout" -> Seq( + AlternateConfig("spark.akka.lookupTimeout", "1.4")) + ) + + /** + * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated + * config keys. + * + * Maps the deprecated config name to a 2-tuple (new config name, alternate config info). + */ + private val allAlternatives: Map[String, (String, AlternateConfig)] = { + configsWithAlternatives.keys.flatMap { key => + configsWithAlternatives(key).map { cfg => (cfg.key -> (key -> cfg)) } + }.toMap + } + /** * Return whether the given config is an akka config (e.g. akka.actor.provider). * Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout). @@ -367,6 +470,7 @@ private[spark] object SparkConf { isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth") || + name.startsWith("spark.ssl") || isSparkPortConf(name) } @@ -376,4 +480,59 @@ private[spark] object SparkConf { def isSparkPortConf(name: String): Boolean = { (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") } + + /** + * Looks for available deprecated keys for the given config option, and return the first + * value available. + */ + def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + configsWithAlternatives.get(key).flatMap { alts => + alts.collectFirst { case alt if conf.contains(alt.key) => + val value = conf.get(alt.key) + if (alt.translation != null) alt.translation(value) else value + } + } + } + + /** + * Logs a warning message if the given config key is deprecated. + */ + def logDeprecationWarning(key: String): Unit = { + deprecatedConfigs.get(key).foreach { cfg => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. ${cfg.deprecationMessage}") + } + + allAlternatives.get(key).foreach { case (newKey, cfg) => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"and may be removed in the future. Please use the new key '$newKey' instead.") + } + } + + /** + * Holds information about keys that have been deprecated and do not have a replacement. + * + * @param key The deprecated key. + * @param version Version of Spark where key was deprecated. + * @param deprecationMessage Message to include in the deprecation warning. + */ + private case class DeprecatedConfig( + key: String, + version: String, + deprecationMessage: String) + + /** + * Information about an alternate configuration key that has been deprecated. + * + * @param key The deprecated config key. + * @param version The Spark version in which the key was deprecated. + * @param translation A translation function for converting old config values into new ones. + */ + private case class AlternateConfig( + key: String, + version: String, + translation: String => String = null) + } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a354ed4d148..86269eac52db 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,33 +20,44 @@ package org.apache.spark import scala.language.implicitConversions import java.io._ +import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID + import scala.collection.{Map, Set} import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, + FixedLengthBinaryInputFormat} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ @@ -85,6 +96,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + private val stopped: AtomicBoolean = new AtomicBoolean(false) + + private def assertNotStopped(): Unit = { + if (stopped.get()) { + throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -174,9 +193,43 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - - private[spark] val conf = config.clone() - conf.validateSettings() + + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. | + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be @@ -184,56 +237,33 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } - - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") - - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") - - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { - if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) - } else { - None - } - } + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = _eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - - val isLocal = (master == "local" || master.startsWith("local[")) - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + // This function allows components created by SparkEnv to be mocked in unit tests: + private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + } - // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -241,164 +271,263 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener + def statusTracker: SparkStatusTracker = _statusTracker - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - val statusTracker = new SparkStatusTracker(this) + private[spark] def ui: Option[SparkUI] = _ui - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } + /** + * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. + */ + def hadoopConfiguration: Configuration = _hadoopConfiguration - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } + private[spark] def executorMemory: Int = _executorMemory + + // Environment variables to pass to our executors. + private[spark] val executorEnvs = HashMap[String, String]() - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() + + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb + } - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts + } - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds } - if (files != null) { - files.foreach(addFile) + def applicationId: String = _applicationId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } + /* ------------------------------------------------------------------------------------- * + | Initialization. This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) + try { + _conf = config.clone() + _conf.validateSettings() - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v - } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. - executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } - // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Option { - Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name")) - }.getOrElse { - SparkContext.SPARK_UNKNOWN_USER - } - executorEnvs("SPARK_USER") = sparkUser - - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") - @volatile private[spark] var dagScheduler: DAGScheduler = _ - try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") + + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) + + _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten + + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } + + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val applicationId: String = taskScheduler.applicationId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - val metricsSystem = env.metricsSystem + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = - new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. - private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(master.contains("yarn") || dynamicAllocationTesting, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager,appName)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() + if (files != null) { + files.foreach(addFile) + } - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(512) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value } - } - cleaner.foreach(_.start()) + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.send(TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _eventLogDir.get, _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - postEnvironmentUpdate() - postApplicationStart() + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + _executorAllocationManager = + if (dynamicAllocationEnabled) { + assert(supportDynamicAllocation, + "Dynamic allocation of executors is currently only supported in YARN mode") + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - private[spark] var checkpointDir: Option[String] = None + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -412,10 +541,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -440,9 +571,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -503,19 +631,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. @@ -523,8 +638,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call * to parallelize and before the first action on the RDD, the resultant RDD will reflect the * modified collection. Pass a copy of the argument to avoid this. + * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an + * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -540,6 +658,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -549,6 +668,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -582,6 +702,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -627,6 +748,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -644,6 +766,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Load data from a flat binary file, assuming the length of each record is constant. * + * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * has the provided record length. + * * @param path Directory to the input data files * @param recordLength The length at which to split the records * @return An RDD of data with values, represented as byte arrays @@ -651,13 +776,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], conf=conf) - val data = br.map{ case (k, v) => v.getBytes} + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } data } @@ -666,16 +796,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), * using the older MapReduce API (`org.apache.hadoop.mapred`). * - * @param conf JobConf for setting up the dataset + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. * @param inputFormatClass Class of the InputFormat * @param keyClass Class of the keys * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopRDD[K, V]( conf: JobConf, @@ -684,18 +818,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - * */ + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + */ def hadoopFile[K, V]( path: String, inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -703,6 +839,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -725,9 +862,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) @@ -748,9 +886,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = @@ -772,9 +911,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * and extra configuration options to pass to the input format. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( path: String, @@ -782,6 +922,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() + // The call to new NewHadoopJob automatically adds security credentials to conf, + // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -792,31 +935,46 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) + assertNotStopped() + // Add necessary security credentials to the JobConf. Required to access secure HDFS. + val jconf = new JobConf(conf) + SparkHadoopUtil.get.addCredentials(jconf) + new NewHadoopRDD(this, fClass, kClass, vClass, jconf) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -824,13 +982,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -849,15 +1009,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * allow it to figure out the Writable class to use in the subclass case. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. */ def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -879,6 +1041,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -890,14 +1053,21 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Build the union of a list of RDDs. */ - def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { + val partitioners = rdds.flatMap(_.partitioner).toSet + if (partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, rdds) + } else { + new UnionRDD(this, rdds) + } + } /** Build the union of a list of RDDs passed as variable-length arguments. */ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) + union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -905,16 +1075,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = + { + val acc = new Accumulator(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { - new Accumulator(initialValue, param, Some(name)) + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) + : Accumulator[T] = { + val acc = new Accumulator(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -923,8 +1100,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param) + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the @@ -933,8 +1114,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = - new Accumulable(initialValue, param, Some(name)) + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { + val acc = new Accumulable(initialValue, param, Some(name)) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } /** * Create an accumulator from a "mutable collection" type. @@ -945,7 +1130,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) + val acc = new Accumulable(initialValue, param) + cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc } /** @@ -954,6 +1141,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -967,12 +1161,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ - def addFile(path: String) { + def addFile(path: String): Unit = { + addFile(path, false) + } + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * + * A directory can be given if the recursive option is set to true. Currently directories are only + * supported for Hadoop-supported filesystems. + */ + def addFile(path: String, recursive: Boolean): Unit = { val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case "local" => "file:" + uri.getPath - case _ => path + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => new File(path).getCanonicalFile.toURI.toString + case _ => path + } + + val hadoopPath = new Path(schemeCorrectedPath) + val scheme = new URI(schemeCorrectedPath).getScheme + if (!Array("http", "https", "ftp").contains(scheme)) { + val fs = hadoopPath.getFileSystem(hadoopConfiguration) + if (!fs.exists(hadoopPath)) { + throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") + } + val isDir = fs.getFileStatus(hadoopPath).isDir + if (!isLocal && scheme == "file" && isDir) { + throw new SparkException(s"addFile does not support local directories when not running " + + "local mode.") + } + if (!recursive && isDir) { + throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + + "turned on.") + } + } + + val key = if (!isLocal && scheme == "file") { + env.httpFileServer.addFile(new File(uri.getPath)) + } else { + schemeCorrectedPath } val timestamp = System.currentTimeMillis addedFiles(key) = timestamp @@ -985,6 +1215,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli postEnvironmentUpdate() } + /** + * Return whether dynamically adjusting the amount of resources allocated to + * this application is supported. This is currently only available for YARN. + */ + private[spark] def supportDynamicAllocation = + master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) + /** * :: DeveloperApi :: * Register a listener to receive up-calls from events that happen during execution. @@ -994,14 +1231,31 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } + /** + * Express a preference to the cluster manager for a given total number of executors. + * This can result in canceling pending requests or filing additional requests. + * This is currently only supported in YARN mode. Return whether the request is received. + */ + private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + assert(supportDynamicAllocation, + "Requesting executors is currently only supported in YARN mode") + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestTotalExecutors(numExecutors) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } + } + /** * :: DeveloperApi :: * Request an additional number of executors from the cluster manager. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def requestExecutors(numAdditionalExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1015,11 +1269,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. - * This is currently only supported in Yarn mode. Return whether the request is received. + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutors(executorIds: Seq[String]): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Killing executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1039,13 +1293,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION + def version: String = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1058,6 +1313,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1075,6 +1331,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1084,6 +1341,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1094,6 +1352,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1101,6 +1360,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1175,7 +1435,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli null } } else { - env.httpFileServer.addJar(new File(uri.getPath)) + try { + env.httpFileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null + case e: Exception => + // For now just log an error but allow to go through so spark examples work. + // The spark examples don't really need the jar distributed since its also + // the app jar. + logError("Error adding jar (" + e + "), was the --addJars option used?") + null + } } // A JAR file which exists locally on every worker node case "local" => @@ -1201,33 +1473,46 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - postApplicationEnd() - ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - env.metricsSystem.report() - metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") - } + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. + if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return + } + + postApplicationEnd() + _ui.foreach(_.stop()) + if (env != null) { + env.metricsSystem.report() } + if (metadataCleaner != null) { + metadataCleaner.cancel() + } + _cleaner.foreach(_.stop()) + _executorAllocationManager.foreach(_.stop()) + if (_dagScheduler != null) { + _dagScheduler.stop() + _dagScheduler = null + } + if (_listenerBusStarted) { + listenerBus.stop() + _listenerBusStarted = false + } + _eventLogger.foreach(_.stop()) + if (env != null && _heartbeatReceiver != null) { + env.rpcEnv.stop(_heartbeatReceiver) + } + _progressBar.foreach(_.stop()) + _taskScheduler = null + // TODO: Cache.stop()? + if (_env != null) { + _env.stop() + SparkEnv.set(null) + } + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1289,12 +1574,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped.get()) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) @@ -1377,6 +1665,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1399,6 +1688,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1417,11 +1707,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. */ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1465,16 +1757,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - def getCheckpointDir = checkpointDir + def getCheckpointDir: Option[String] = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. + * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) @@ -1486,6 +1785,59 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** + * Registers listeners specified in spark.extraListeners, then starts the listener bus. + * This should be called after all internal listeners have been registered with the listener bus + * (e.g. after the web UI and event logging listeners have been registered). + */ + private def setupAndStartListenerBus(): Unit = { + // Use reflection to instantiate listeners specified via `spark.extraListeners` + try { + val listenerClassNames: Seq[String] = + conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") + for (className <- listenerClassNames) { + // Use reflection to find the right constructor + val constructors = { + val listenerClass = Class.forName(className) + listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] + } + val constructorTakingSparkConf = constructors.find { c => + c.getParameterTypes.sameElements(Array(classOf[SparkConf])) + } + lazy val zeroArgumentConstructor = constructors.find { c => + c.getParameterTypes.isEmpty + } + val listener: SparkListener = { + if (constructorTakingSparkConf.isDefined) { + constructorTakingSparkConf.get.newInstance(conf) + } else if (zeroArgumentConstructor.isDefined) { + zeroArgumentConstructor.get.newInstance() + } else { + throw new SparkException( + s"$className did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the listener as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + } + } + listenerBus.addListener(listener) + logInfo(s"Registered listener $className") + } + } catch { + case e: Exception => + try { + stop() + } finally { + throw new SparkException(s"Exception when registering SparkListener", e) + } + } + + listenerBus.start(this) + _listenerBusStarted = true + } + /** Post the application start event */ private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted @@ -1505,8 +1857,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq - val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, + addedFilePaths) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } @@ -1535,11 +1887,12 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private var activeContext: Option[SparkContext] = None + private val activeContext: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null) /** * Points to a partially-constructed SparkContext if some thread is in the SparkContext @@ -1574,7 +1927,8 @@ object SparkContext extends Logging { logWarning(warnMsg) } - activeContext.foreach { ctx => + if (activeContext.get() != null) { + val ctx = activeContext.get() val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" @@ -1589,6 +1943,39 @@ object SparkContext extends Logging { } } + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(config: SparkConf): SparkContext = { + // Synchronize to ensure that multiple create requests don't trigger an exception + // from assertNoOtherContextIsRunning within setActiveContext + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + if (activeContext.get() == null) { + setActiveContext(new SparkContext(config), allowMultipleContexts = false) + } + activeContext.get() + } + } + + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * This method allows not passing a SparkConf (useful if just retrieving). + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(): SparkContext = { + getOrCreate(new SparkConf()) + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -1615,7 +2002,7 @@ object SparkContext extends Logging { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { assertNoOtherContextIsRunning(sc, allowMultipleContexts) contextBeingConstructed = None - activeContext = Some(sc) + activeContext.set(sc) } } @@ -1626,7 +2013,7 @@ object SparkContext extends Logging { */ private[spark] def clearActiveContext(): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - activeContext = None + activeContext.set(null) } } @@ -1636,9 +2023,17 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val SPARK_UNKNOWN_USER = "" + /** + * Executor id for the driver. In earlier versions of Spark, this was ``, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" - private[spark] val DRIVER_IDENTIFIER = "" + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically. They are duplicate codes only for backward @@ -1649,28 +2044,28 @@ object SparkContext extends Logging { "backward compatibility.", "1.3.0") object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // The following deprecated functions have already been moved to `object RDD` to @@ -1680,49 +2075,71 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = RDD.rddToPairRDDFunctions(rdd) - } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = + RDD.rddToAsyncRDDActions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { + val kf = implicitly[K => Writable] + val vf = implicitly[V => Writable] + // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it + implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) + implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) RDD.rddToSequenceFileRDDFunctions(rdd) + } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = RDD.rddToOrderedRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = + RDD.doubleRDDToDoubleRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = RDD.numericRDDToDoubleRDDFunctions(rdd) - // Implicit conversions to common Writable types, for saveAsSequenceFile + // The following deprecated functions have already been moved to `object WritableFactory` to + // make the compiler find them automatically. They are still kept here for backward compatibility. + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) + @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + + "kept here only for backward compatibility.", "1.3.0") implicit def stringToText(s: String): Text = new Text(s) private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) @@ -1857,29 +2274,29 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) case LOCAL_N_REGEX(threads) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt if (threadCount <= 0) { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -1901,7 +2318,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) @@ -1942,7 +2359,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2012,7 +2429,7 @@ object WritableConverter { new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -2045,3 +2462,46 @@ object WritableConverter { implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] = new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) } + +/** + * A class encapsulating how to convert some type T to Writable. It stores both the Writable class + * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. + * The Writable class will be used in `SequenceFileRDDFunctions`. + */ +private[spark] class WritableFactory[T]( + val writableClass: ClassTag[T] => Class[_ <: Writable], + val convert: T => Writable) extends Serializable + +object WritableFactory { + + private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W) + : WritableFactory[T] = { + val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]] + new WritableFactory[T](_ => writableClass, convert) + } + + implicit def intWritableFactory: WritableFactory[Int] = + simpleWritableFactory(new IntWritable(_)) + + implicit def longWritableFactory: WritableFactory[Long] = + simpleWritableFactory(new LongWritable(_)) + + implicit def floatWritableFactory: WritableFactory[Float] = + simpleWritableFactory(new FloatWritable(_)) + + implicit def doubleWritableFactory: WritableFactory[Double] = + simpleWritableFactory(new DoubleWritable(_)) + + implicit def booleanWritableFactory: WritableFactory[Boolean] = + simpleWritableFactory(new BooleanWritable(_)) + + implicit def bytesWritableFactory: WritableFactory[Array[Byte]] = + simpleWritableFactory(new BytesWritable(_)) + + implicit def stringWritableFactory: WritableFactory[String] = + simpleWritableFactory(new Text(_)) + + implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] = + simpleWritableFactory(w => w) + +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33..959aefabd8de 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -34,11 +33,14 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -53,7 +55,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -67,8 +69,12 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -76,6 +82,8 @@ class SparkEnv ( // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() + private var driverTmpDirToDelete: Option[String] = None + private[spark] def stop() { isStopped = true pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -86,13 +94,31 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() - actorSystem.shutdown() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } + } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor + } } private[spark] @@ -151,7 +177,8 @@ object SparkEnv extends Logging { private[spark] def createDriverEnv( conf: SparkConf, isLocal: Boolean, - listenerBus: LiveListenerBus): SparkEnv = { + listenerBus: LiveListenerBus, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") @@ -163,7 +190,8 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, - listenerBus = listenerBus + listenerBus = listenerBus, + mockOutputCommitCoordinator = mockOutputCommitCoordinator ) } @@ -202,7 +230,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0): SparkEnv = { + numUsableCores: Int = 0, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -212,16 +241,15 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -257,12 +285,14 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + rpcEnv.setupEndpoint(name, endpointCreator) } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) + RpcUtils.makeDriverRef(name, conf, rpcEnv) } } @@ -274,9 +304,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -296,12 +326,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) @@ -326,6 +357,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. + conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms @@ -335,20 +370,21 @@ object SparkEnv extends Logging { // this is a temporary directory; in distributed mode, this is the executor's current working // directory. val sparkFilesDir: String = if (isDriver) { - Utils.createTempDir().getAbsolutePath + Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath } else { "." } - // Warn about deprecated spark.cache.class property - if (conf.contains("spark.cache.class")) { - logWarning("The spark.cache.class property is no longer being used! Specify storage " + - "levels using the RDD.persist() method instead.") + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { + new OutputCommitCoordinator(conf) } + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - new SparkEnv( + val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, @@ -362,7 +398,17 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + outputCommitCoordinator, conf) + + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is + // called, and we only need to do it for driver. Because driver may run as a service, and if we + // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. + if (isDriver) { + envInstance.driverTmpDirToDelete = Some(sparkFilesDir) + } + + envInstance } /** diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 40237596570d..2ec42d3aea16 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -103,26 +103,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - } else { - logInfo ("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask( + getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) } def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? val cmtr = getOutputCommitter() cmtr.commitJob(getJobContext()) } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb..34ee3a48f8e7 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala new file mode 100644 index 000000000000..7d7fe1a44631 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.Serializable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + + +object TaskContext { + /** + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. + */ + def get(): TaskContext = taskContext.get + + private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + + // Note: protected[spark] instead of private[spark] to prevent the following two from + // showing up in JavaDoc. + /** + * Set the thread local TaskContext. Internal to Spark. + */ + protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc) + + /** + * Unset the thread local TaskContext. Internal to Spark. + */ + protected[spark] def unset(): Unit = taskContext.remove() +} + + +/** + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task, use: + * {{{ + * org.apache.spark.TaskContext.get() + * }}} + */ +abstract class TaskContext extends Serializable { + // Note: TaskContext must NOT define a get method. Otherwise it will prevent the Scala compiler + // from generating a static get method (based on the companion object's get method). + + // Note: Update JavaTaskContextCompileCheck when new methods are added to this class. + + // Note: getters in this class are defined with parentheses to maintain backward compatibility. + + /** + * Returns true if the task has completed. + */ + def isCompleted(): Boolean + + /** + * Returns true if the task has been killed. + */ + def isInterrupted(): Boolean + + @deprecated("use isRunningLocally", "1.2.0") + def runningLocally(): Boolean + + /** + * Returns true if the task is running locally in the driver program. + * @return + */ + def isRunningLocally(): Boolean + + /** + * Adds a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext + + /** + * Adds a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + + /** + * Adds a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * @param f Callback function. + */ + @deprecated("use addTaskCompletionListener", "1.2.0") + def addOnCompleteCallback(f: () => Unit) + + /** + * The ID of the stage that this task belong to. + */ + def stageId(): Int + + /** + * The ID of the RDD partition that is computed by this task. + */ + def partitionId(): Int + + /** + * How many times this task has been attempted. The first task attempt will be assigned + * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. + */ + def attemptNumber(): Int + + @deprecated("use attemptNumber", "1.3.0") + def attemptId(): Long + + /** + * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts + * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. + */ + def taskAttemptId(): Long + + /** ::DeveloperApi:: */ + @DeveloperApi + def taskMetrics(): TaskMetrics +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9bb0c61e441f..337c8e4ebebc 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -33,7 +33,7 @@ private[spark] class TaskContextImpl( with Logging { // For backwards-compatibility; this method is now deprecated as of 1.3.0. - override def attemptId: Long = taskAttemptId + override def attemptId(): Long = taskAttemptId // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -87,10 +87,10 @@ private[spark] class TaskContextImpl( interrupted = true } - override def isCompleted: Boolean = completed + override def isCompleted(): Boolean = completed - override def isRunningLocally: Boolean = runningLocally + override def isRunningLocally(): Boolean = runningLocally - override def isInterrupted: Boolean = interrupted + override def isInterrupted(): Boolean = interrupted } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index af5fd8e0ac00..48fd3e7e23d5 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -146,6 +146,16 @@ case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" } +/** + * :: DeveloperApi :: + * Task requested the driver to commit, but was denied. + */ +@DeveloperApi +case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { + override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" +} + /** * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. This may happen because diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index 0bf1e4a5e2cc..fe19f07e32d1 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,7 +27,9 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + def isFailed(state: TaskState): Boolean = (LOST == state) || (FAILED == state) + + def isFinished(state: TaskState): Boolean = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { case LAUNCHING => MesosTaskState.TASK_STARTING @@ -46,5 +48,6 @@ private[spark] object TaskState extends Enumeration { case MesosTaskState.TASK_FAILED => FAILED case MesosTaskState.TASK_KILLED => KILLED case MesosTaskState.TASK_LOST => LOST + case MesosTaskState.TASK_ERROR => LOST } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 34078142f538..398ca41e1615 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConversions._ +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -43,13 +44,38 @@ private[spark] object TestUtils { * Note: if this is used during class loader tests, class names should be unique * in order to avoid interference between tests. */ - def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { + def createJarWithClasses( + classNames: Seq[String], + toStringValue: String = "", + classNamesWithBase: Seq[(String, String)] = Seq(), + classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() - val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) + val files1 = for (name <- classNames) yield { + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + } + val files2 = for ((childName, baseName) <- classNamesWithBase) yield { + createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) + } val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) - createJar(files, jarFile) + createJar(files1 ++ files2, jarFile) } + /** + * Create a jar file containing multiple files. The `files` map contains a mapping of + * file names in the jar file to their contents. + */ + def createJarWithFiles(files: Map[String, String], dir: File = null): URL = { + val tempDir = Option(dir).getOrElse(Utils.createTempDir()) + val jarFile = File.createTempFile("testJar", ".jar", tempDir) + val jarStream = new JarOutputStream(new FileOutputStream(jarFile)) + files.foreach { case (k, v) => + val entry = new JarEntry(k) + jarStream.putNextEntry(entry) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + } + jarStream.close() + jarFile.toURI.toURL + } /** * Create a jar file that contains this set of files. All files will be located at the root @@ -81,19 +107,30 @@ private[spark] object TestUtils { private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { - override def getCharContent(ignoreEncodingErrors: Boolean) = code + override def getCharContent(ignoreEncodingErrors: Boolean): String = code } /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass(className: String, destDir: File, value: String = ""): File = { + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { val compiler = ToolProvider.getSystemJavaCompiler + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") val sourceFile = new JavaSourceFromString(className, - "public class " + className + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + value + "\"; }}") + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. - compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call() + val options = if (classpathUrls.nonEmpty) { + Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) + } else { + Seq() + } + compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 8e8f7f6c4fda..61af867b11b9 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -32,7 +32,8 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter import org.apache.spark.util.Utils -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] { +class JavaDoubleRDD(val srdd: RDD[scala.Double]) + extends AbstractJavaRDDLike[JDouble, JavaDoubleRDD] { override val classTag: ClassTag[JDouble] = implicitly[ClassTag[JDouble]] @@ -162,6 +163,20 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja /** Add up the elements in this RDD. */ def sum(): JDouble = srdd.sum() + /** + * Returns the minimum element from this RDD as defined by + * the default comparator natural order. + * @return the minimum of the RDD + */ + def min(): JDouble = min(com.google.common.collect.Ordering.natural()) + + /** + * Returns the maximum element from this RDD as defined by + * the default comparator natural order. + * @return the maximum of the RDD + */ + def max(): JDouble = max(com.google.common.collect.Ordering.natural()) + /** * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 7af3538262fd..8441bb3a3047 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -39,12 +39,13 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.rdd.RDD.rddToPairRDDFunctions +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) - extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { + extends AbstractJavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) @@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). + * In addition, users can control the partitioning of the output RDD, the serializer that is use + * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple + * items with the same key). */ def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val ctag: ClassTag[C] = fakeClassTag + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner, + mapSideCombine: Boolean, + serializer: Serializer): JavaPairRDD[K, C] = { + implicit val ctag: ClassTag[C] = fakeClassTag fromRDD(rdd.combineByKey( createCombiner, mergeValue, mergeCombiners, - partitioner + partitioner, + mapSideCombine, + serializer )) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD. This method automatically + * uses map-side aggregation in shuffling the RDD. + */ + def combineByKey[C](createCombiner: JFunction[V, C], + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + partitioner: Partitioner): JavaPairRDD[K, C] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null) + } + + /** + * Simplified version of combineByKey that hash-partitions the output RDD and uses map-side + * aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing - * partitioner/parallelism level. + * partitioner/parallelism level and using map-side aggregation. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -633,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 86fb374bef1e..db4e996feb31 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) - extends JavaRDDLike[T, JavaRDD[T]] { + extends AbstractJavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) @@ -101,12 +101,23 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -168,7 +179,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - override def toString = rdd.toString + override def toString: String = rdd.toString /** Assign a name to this RDD */ def setName(name: String): JavaRDD[T] = { @@ -181,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 62bf18d82d9b..8bf0627fc420 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -38,6 +39,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaRDDLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This]] + extends JavaRDDLike[T, This] + /** * Defines operations common to several Java RDD implementations. * Note that this trait is not intended to be implemented by user code. @@ -85,7 +94,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -101,7 +110,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -111,7 +120,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -121,8 +130,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -131,8 +140,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -140,7 +149,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -149,7 +160,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -158,8 +171,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -167,7 +182,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -176,7 +193,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -186,7 +205,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -269,8 +290,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { + (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -348,6 +371,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]] + */ + def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth) + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2. + */ + def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2) + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -369,6 +405,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { combOp: JFunction2[U, U, U]): U = rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U]) + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]] + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U], + depth: Int): U = { + rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U]) + } + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2. + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U]): U = { + treeAggregate(zeroValue, seqOp, combOp, 2) + } + /** * Return the number of elements in the RDD. */ @@ -396,8 +456,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 97f5c9f257e0..3be6783bba49 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -108,7 +108,7 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env - def statusTracker = new JavaSparkStatusTracker(sc) + def statusTracker: JavaSparkStatusTracker = new JavaSparkStatusTracker(sc) def isLocal: java.lang.Boolean = sc.isLocal @@ -373,6 +373,15 @@ class JavaSparkContext(val sc: SparkContext) * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minPartitions Minimum number of Hadoop Splits to generate. + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -395,6 +404,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * + * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param inputFormatClass Class of the InputFormat + * @param keyClass Class of the keys + * @param valueClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -476,6 +493,14 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * + * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast. + * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make + * sure you won't modify the conf. A safe approach is always creating a new conf for + * a new RDD. + * @param fClass Class of the InputFormat + * @param kClass Class of the keys + * @param vClass Class of the values + * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using @@ -675,6 +700,9 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + * + * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { sc.hadoopConfiguration diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 71b26737b8c0..8f9647eea9e2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.java +import java.util.Map.Entry + import com.google.common.base.Optional import java.{util => ju} @@ -30,8 +32,8 @@ private[spark] object JavaUtils { } // Workaround for SPARK-3926 / SI-8911 - def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = - new SerializableMapWrapper(underlying) + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]): SerializableMapWrapper[A, B] + = new SerializableMapWrapper(underlying) // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper, // but implements java.io.Serializable. It can't just be subclassed to make it @@ -40,36 +42,33 @@ private[spark] object JavaUtils { class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) extends ju.AbstractMap[A, B] with java.io.Serializable { self => - override def size = underlying.size + override def size: Int = underlying.size override def get(key: AnyRef): B = try { - underlying get key.asInstanceOf[A] match { - case None => null.asInstanceOf[B] - case Some(v) => v - } + underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { case ex: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { - def size = self.size + override def size: Int = self.size - def iterator = new ju.Iterator[ju.Map.Entry[A, B]] { + override def iterator: ju.Iterator[ju.Map.Entry[A, B]] = new ju.Iterator[ju.Map.Entry[A, B]] { val ui = underlying.iterator var prev : Option[A] = None - def hasNext = ui.hasNext + def hasNext: Boolean = ui.hasNext - def next() = { - val (k, v) = ui.next + def next(): Entry[A, B] = { + val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { import scala.util.hashing.byteswap32 - def getKey = k - def getValue = v - def setValue(v1 : B) = self.put(k, v1) - override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) - override def equals(other: Any) = other match { + override def getKey: A = k + override def getValue: B = v + override def setValue(v1 : B): B = self.put(k, v1) + override def hashCode: Int = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) + override def equals(other: Any): Boolean = other match { case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue case _ => false } @@ -81,7 +80,7 @@ private[spark] object JavaUtils { case Some(k) => underlying match { case mm: mutable.Map[A, _] => - mm remove k + mm.remove(k) prev = None case _ => throw new UnsupportedOperationException("remove") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala new file mode 100644 index 000000000000..164e95081583 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.DataOutputStream +import java.net.Socket + +import py4j.GatewayServer + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port + * back to its caller via a callback port specified by the caller. + * + * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). + */ +private[spark] object PythonGatewayServer extends Logging { + def main(args: Array[String]): Unit = Utils.tryOrExit { + // Start a GatewayServer on an ephemeral port + val gatewayServer: GatewayServer = new GatewayServer(null, 0) + gatewayServer.start() + val boundPort: Int = gatewayServer.getListeningPort + if (boundPort == -1) { + logError("GatewayServer failed to bind; exiting") + System.exit(1) + } else { + logDebug(s"Started PythonGatewayServer on port $boundPort") + } + + // Communicate the bound port back to the caller via the caller-specified callback port + val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") + val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt + logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") + val callbackSocket = new Socket(callbackHost, callbackPort) + val dos = new DataOutputStream(callbackSocket.getOutputStream) + dos.writeInt(boundPort) + dos.close() + callbackSocket.close() + + // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: + while (System.in.read() != -1) { + // Do nothing + } + logDebug("Exiting due to broken pipe from Python driver") + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 5ba66178e2b7..c9181a29d475 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable + case array: Array[Any] => { + val arrayWriteable = new ArrayWritable(classOf[Writable]) + arrayWriteable.set(array.map(convertToWritable(_))) + arrayWriteable + } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4ac666c54fbc..7409dc2d866f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,26 +19,27 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} - -import org.apache.spark.input.PortableDataStream +import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} + import org.apache.spark._ -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.util.control.NonFatal + private[spark] class PythonRDD( @transient parent: RDD[_], command: Array[Byte], @@ -53,9 +54,11 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = firstParent.partitions + override def getPartitions: Array[Partition] = firstParent.partitions - override val partitioner = if (preservePartitoning) firstParent.partitioner else None + override val partitioner: Option[Partitioner] = { + if (preservePartitoning) firstParent.partitioner else None + } override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -67,17 +70,15 @@ private[spark] class PythonRDD( envVars += ("SPARK_REUSE_WORKER" -> "1") } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + // Whether is the worker released into idle pool + @volatile var released = false // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - writerThread.join() - if (reuse_worker && complete_cleanly) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) - } else { + if (!reuse_worker || !released) { try { worker.close() } catch { @@ -93,7 +94,7 @@ private[spark] class PythonRDD( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { - def next(): Array[Byte] = { + override def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { _nextObj = read() @@ -145,8 +146,12 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - complete_cleanly = true + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + released = true + } } null } @@ -172,7 +177,7 @@ private[spark] class PythonRDD( var _nextObj = read() - def hasNext = _nextObj != null + override def hasNext: Boolean = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } @@ -216,14 +221,13 @@ private[spark] class PythonRDD( val oldBids = PythonRDD.getWorkerBroadcasts(worker) val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts - val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size dataOut.writeInt(cnt) - for (bid <- oldBids) { - if (!newBids.contains(bid)) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) } for (broadcast <- broadcastVars) { if (!oldBids.contains(broadcast.id)) { @@ -245,13 +249,17 @@ private[spark] class PythonRDD( } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) - worker.shutdownOutput() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - worker.shutdownOutput() + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } } finally { // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() @@ -297,10 +305,10 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Long, Array[Byte])](prev) { - override def getPartitions = prev.partitions - override def compute(split: Partition, context: TaskContext) = +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { + override def getPartitions: Array[Partition] = prev.partitions + override val partitioner: Option[Partitioner] = prev.partitioner + override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) @@ -313,6 +321,7 @@ private object SpecialLengths { val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 val END_OF_STREAM = -4 + val NULL = -5 } private[spark] object PythonRDD extends Logging { @@ -325,24 +334,45 @@ private[spark] object PythonRDD extends Logging { } } + /** + * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true + * + * This is useful for PySpark to have the partitioner after partitionBy() + */ + def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = { + pair.rdd.mapPartitions(it => it.map(_._2), true) + } + /** * Adapter for calling SparkContext#runJob from Python. * - * This method will return an iterator of an array that contains all elements in the RDD + * This method will serve an iterator of an array that contains all elements in the RDD * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. + * + * @return the port number of a local socket which serves the data collected from this job. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], partitions: JArrayList[Int], - allowLocal: Boolean): Iterator[Array[Byte]] = { + allowLocal: Boolean): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) - flattenedPartition.iterator + serveIterator(flattenedPartition.iterator, + s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + } + + /** + * A helper function to collect an RDD as an iterator, then serve it via socket. + * + * @return the port number of a local socket which serves the data collected from this job. + */ + def collectAndServe[T](rdd: RDD[T]): Int = { + serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -371,54 +401,25 @@ private[spark] object PythonRDD extends Logging { } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { - // The right way to implement this would be to use TypeTags to get the full - // type of T. Since I don't want to introduce breaking changes throughout the - // entire Spark API, I have to use this hacky approach: - if (iter.hasNext) { - val first = iter.next() - val newIter = Seq(first).iterator ++ iter - first match { - case arr: Array[Byte] => - newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes => - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case string: String => - newIter.asInstanceOf[Iterator[String]].foreach { str => - writeUTF(str, dataOut) - } - case stream: PortableDataStream => - newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case (key: String, stream: PortableDataStream) => - newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { - case (key, stream) => - writeUTF(key, dataOut) - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - case (key: String, value: String) => - newIter.asInstanceOf[Iterator[(String, String)]].foreach { - case (key, value) => - writeUTF(key, dataOut) - writeUTF(value, dataOut) - } - case (key: Array[Byte], value: Array[Byte]) => - newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { - case (key, value) => - dataOut.writeInt(key.length) - dataOut.write(key) - dataOut.writeInt(value.length) - dataOut.write(value) - } - case other => - throw new SparkException("Unexpected element type " + first.getClass) - } + + def write(obj: Any): Unit = obj match { + case null => + dataOut.writeInt(SpecialLengths.NULL) + case arr: Array[Byte] => + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => + writeUTF(str, dataOut) + case stream: PortableDataStream => + write(stream.toArray()) + case (key, value) => + write(key) + write(value) + case other => + throw new SparkException("Unexpected element type " + other.getClass) } + + iter.foreach(write) } /** @@ -435,7 +436,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, minSplits: Int, - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] @@ -462,7 +463,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -488,7 +489,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -505,7 +506,7 @@ private[spark] object PythonRDD extends Logging { inputFormatClass: String, keyClass: String, valueClass: String, - conf: Configuration) = { + conf: Configuration): RDD[(K, V)] = { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] @@ -531,7 +532,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -557,7 +558,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -591,15 +592,43 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } - def writeToFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeToFile(items.asScala, filename) - } + /** + * Create a socket server and a background thread to serve the data in `items`, + * + * The socket server can only accept one connection, or close if no connection + * in 3 seconds. + * + * Once a connection comes in, it tries to serialize all the data in `items` + * and send them into this connection. + * + * The thread will terminate after all the data are sent or any exceptions happen. + */ + private def serveIterator[T](items: Iterator[T], threadName: String): Int = { + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + // Close the socket if no connection in 3 seconds + serverSocket.setSoTimeout(3000) + + new Thread(threadName) { + setDaemon(true) + override def run() { + try { + val sock = serverSocket.accept() + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + Utils.tryWithSafeFinally { + writeIteratorToStream(items, out) + } { + out.close() + } + } catch { + case NonFatal(e) => + logError(s"Error while sending iterator", e) + } finally { + serverSocket.close() + } + } + }.start() - def writeToFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - writeIteratorToStream(items, file) - file.close() + serverSocket.getLocalPort } private def getMergedConf(confAsMap: java.util.HashMap[String, String], @@ -657,7 +686,7 @@ private[spark] object PythonRDD extends Logging { pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { saveAsHadoopFile( pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", null, null, null, null, new java.util.HashMap(), compressionCodecClass) @@ -682,7 +711,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -712,7 +741,7 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String]): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -737,7 +766,7 @@ private[spark] object PythonRDD extends Logging { confAsMap: java.util.HashMap[String, String], keyConverterClass: String, valueConverterClass: String, - useNewAPI: Boolean) = { + useNewAPI: Boolean): Unit = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), keyConverterClass, valueConverterClass, new JavaToWritableConverter) @@ -833,9 +862,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index be5ebfa9219d..acbaba679185 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,11 +17,14 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.{File} +import java.util.{List => JList} +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} private[spark] object PythonUtils { /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ @@ -39,4 +42,15 @@ private[spark] object PythonUtils { def mergePythonPaths(paths: String*): String = { paths.filter(_ != "").mkString(File.pathSeparator) } + + def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { + sc.parallelize(List("a", null, "b")) + } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](cols: JList[T]): Seq[T] = { + cols.toList.toSeq + } } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index a4153aaa926f..257491e90dd6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -84,7 +84,7 @@ private[spark] object SerDeUtil extends Logging { private var initialized = false // This should be called before trying to unpickle array.array from Python // In cluster mode, this should be put in closure - def initialize() = { + def initialize(): Unit = { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) @@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case array: Array[Any] => array.toSeq + case _ => obj.asInstanceOf[JArrayList[_]].asScala + } } else { Seq(obj) } @@ -199,7 +202,10 @@ private[spark] object SerDeUtil extends Logging { * representation is serialized */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { - val (keyFailed, valueFailed) = checkPickle(rdd.first()) + val (keyFailed, valueFailed) = rdd.take(1) match { + case Array() => (false, false) + case Array(first) => checkPickle(first) + } rdd.mapPartitions { iter => val cleaned = iter.map { case (k, v) => @@ -226,10 +232,12 @@ private[spark] object SerDeUtil extends Logging { } val rdd = pythonToJava(pyRDD, batched).rdd - rdd.first match { - case obj if isPair(obj) => + rdd.take(1) match { + case Array(obj) if isPair(obj) => // we only accept (K, V) - case other => throw new SparkException( + case Array() => + // we also accept empty collections + case Array(other) => throw new SparkException( s"RDD element of type ${other.getClass.getName} cannot be used") } rdd.map { obj => diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index c0cbd28a845b..8f30ff9202c8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,38 +18,37 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} +import java.{util => ju} import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + +import org.apache.spark.SparkException import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.{SparkContext, SparkException} /** * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python - * @param str - * @param int - * @param double */ case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) - def getStr = str + def getStr: String = str def setStr(str: String) { this.str = str } - def getInt = int + def getInt: Int = int def setInt(int: Int) { this.int = int } - def getDouble = double + def getDouble: Double = double def setDouble(double: Double) { this.double = double } - def write(out: DataOutput) = { + def write(out: DataOutput): Unit = { out.writeUTF(str) out.writeInt(int) out.writeDouble(double) } - def readFields(in: DataInput) = { + def readFields(in: DataInput): Unit = { str = in.readUTF() int = in.readInt() double = in.readDouble() @@ -57,28 +56,28 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } private[python] class TestInputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Char = { obj.asInstanceOf[IntWritable].get().toChar } } private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) } } private[python] class TestOutputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Text = { new Text(obj.asInstanceOf[Int].toString) } } private[python] class TestOutputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): DoubleWritable = { new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) } } @@ -86,7 +85,7 @@ private[python] class TestOutputValueConverter extends Converter[Any, Any] { private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { - override def convert(obj: Any) = obj match { + override def convert(obj: Any): DoubleArrayWritable = obj match { case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => val daw = new DoubleArrayWritable daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) @@ -107,7 +106,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra * given directory (probably a temp directory) */ object WriteInputFormatTestDataGenerator { - import SparkContext._ def main(args: Array[String]) { val path = args(0) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 000000000000..3a2c94bd9d87 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 000000000000..0075d963711f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 000000000000..6fea5e1144f2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def readData(length: Int): U + + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 000000000000..371dfe454d1a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + assert(bytes(len - 1) == 0) + val str = new String(bytes.dropRight(1), "UTF-8") + str + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + readStringBytes(in, len) + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index a5ea478f231d..12d79f6ed311 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -146,5 +146,5 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo } } - override def toString = "Broadcast(" + id + ")" + override def toString: String = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2..685313ac009b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -58,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 31d6958c403b..4457c75e8b0f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -151,7 +151,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def createServer(conf: SparkConf) { - broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) + broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast") val broadcastPort = conf.getInt("spark.broadcast.port", 0) server = new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") @@ -160,12 +160,12 @@ private[broadcast] object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,10 +175,13 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } @@ -199,6 +202,7 @@ private[broadcast] object HttpBroadcast extends Logging { uc = new URL(url).openConnection() uc.setConnectTimeout(httpReadTimeout) } + Utils.setupSecureURLConnection(uc, securityManager) val in = { uc.setReadTimeout(httpReadTimeout) @@ -211,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** @@ -221,7 +227,7 @@ private[broadcast] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c7ef02d572a1..cf3ae36f2794 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = new HttpBroadcast[T](value_, isLocal, id) override def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c..23b02e60338f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -222,7 +222,7 @@ private object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index fb024c12094f..96d8dd79908c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,7 +30,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 65a1a8fd7e92..ae99432f5ce8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,16 +17,32 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], - val memoryPerSlave: Int, + val memoryPerExecutorMB: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None) + val eventLogDir: Option[URI] = None, + // short name of compression codec used when writing event logs, if any (e.g. lzf) + val eventLogCodec: Option[String] = None, + val coresPerExecutor: Option[Int] = None) extends Serializable { val user = System.getProperty("user.name", "") + def copy( + name: String = name, + maxCores: Option[Int] = maxCores, + memoryPerExecutorMB: Int = memoryPerExecutorMB, + command: Command = command, + appUiUrl: String = appUiUrl, + eventLogDir: Option[URI] = eventLogDir, + eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = + new ApplicationDescription( + name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) + override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 7c1c831c248f..c2c3e9a9e482 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,7 +27,7 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} /** * Proxy that relays messages to the driver. @@ -36,10 +36,11 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with ActorLogReceive with Logging { var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) - override def preStart() = { - masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) + override def preStart(): Unit = { + masterActor = context.actorSelection( + Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -67,8 +68,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) .map(Utils.splitCommandString).getOrElse(Seq.empty) val sparkJavaOpts = Utils.sparkJavaOpts(conf) val javaOpts = sparkJavaOpts ++ extraJavaOpts - val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ - driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts) + val command = new Command(mainClass, + Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions, + sys.env, classPathEntries, libraryPathEntries, javaOpts) val driverDescription = new DriverDescription( driverArgs.jarUrl, @@ -87,7 +89,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) @@ -116,7 +118,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case SubmitDriverResponse(success, driverId, message) => println(message) @@ -153,7 +155,7 @@ object Client { if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { conf.set("spark.akka.logLifecycleEvents", "true") } - conf.set("spark.akka.askTimeout", "10") + conf.set("spark.rpc.askTimeout", "10") conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) @@ -161,7 +163,7 @@ object Client { "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master) + Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem)) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index e5873ce724b9..5cbac787dcee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -28,9 +28,8 @@ import org.apache.spark.util.{IntParam, MemoryParam} /** * Command-line parser for the driver client. */ -private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 +private[deploy] class ClientArguments(args: Array[String]) { + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN @@ -39,18 +38,18 @@ private[spark] class ClientArguments(args: Array[String]) { var master: String = "" var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() - def driverOptions = _driverOptions.toSeq + def driverOptions: Seq[String] = _driverOptions.toSeq // kill parameters var driverId: String = "" parse(args.toList) - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--cores" | "-c") :: IntParam(value) :: tail => cores = value parse(tail) @@ -97,7 +96,7 @@ private[spark] class ClientArguments(args: Array[String]) { /** * Print usage and exit JVM with the given exit code. */ - def printUsageAndExit(exitCode: Int) { + private def printUsageAndExit(exitCode: Int) { // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars // separately similar to in the YARN client. val usage = @@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin System.err.println(usage) @@ -116,7 +116,11 @@ private[spark] class ClientArguments(args: Array[String]) { } } -object ClientArguments { +private[deploy] object ClientArguments { + val DEFAULT_CORES = 1 + val DEFAULT_MEMORY = 512 // MB + val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 243d8edb72ed..9db6fd1ac4db 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -101,6 +101,8 @@ private[deploy] object DeployMessages { case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage + case class UnregisterApplication(appId: String) + case class MasterChangeAcknowledged(appId: String) // Master to AppClient @@ -148,15 +150,22 @@ private[deploy] object DeployMessages { // Master to MasterWebUI - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], - activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], - status: MasterState) { + case class MasterStateResponse( + host: String, + port: Int, + restPort: Option[Int], + workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], + completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], + completedDrivers: Array[DriverInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) - def uri = "spark://" + host + ":" + port + def uri: String = "spark://" + host + ":" + port + def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala index 58c95dc4f911..659fb434a80f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -private[spark] class DriverDescription( +private[deploy] class DriverDescription( val jarUrl: String, val mem: Int, val cores: Int, @@ -25,5 +25,13 @@ private[spark] class DriverDescription( val command: Command) extends Serializable { + def copy( + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Int = cores, + supervise: Boolean = supervise, + command: Command = command): DriverDescription = + new DriverDescription(jarUrl, mem, cores, supervise, command) + override def toString: String = s"DriverDescription (${command.mainClass})" } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala index 2abf0b69dddb..ec23371b52f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala @@ -22,7 +22,7 @@ package org.apache.spark.deploy * This state is sufficient for the Master to reconstruct its internal data structures during * failover. */ -private[spark] class ExecutorDescription( +private[deploy] class ExecutorDescription( val appId: String, val execId: Int, val cores: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 9f34d01e6db4..efa88c62e1f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -private[spark] object ExecutorState extends Enumeration { +private[deploy] object ExecutorState extends Enumeration { val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 47dbcd87c35b..a7c89276a045 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -33,6 +33,7 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.util.Utils /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. @@ -55,29 +56,29 @@ import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} * - The docker images tagged spark-test-master and spark-test-worker are built from the * docker/ directory. Run 'docker/spark-test/build' to generate these. */ -private[spark] object FaultToleranceTest extends App with Logging { +private object FaultToleranceTest extends App with Logging { - val conf = new SparkConf() - val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + private val conf = new SparkConf() + private val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") - val masters = ListBuffer[TestMasterInfo]() - val workers = ListBuffer[TestWorkerInfo]() - var sc: SparkContext = _ + private val masters = ListBuffer[TestMasterInfo]() + private val workers = ListBuffer[TestWorkerInfo]() + private var sc: SparkContext = _ - val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) - var numPassed = 0 - var numFailed = 0 + private var numPassed = 0 + private var numFailed = 0 - val sparkHome = System.getenv("SPARK_HOME") + private val sparkHome = System.getenv("SPARK_HOME") assertTrue(sparkHome != null, "Run with a valid SPARK_HOME") - val containerSparkHome = "/opt/spark" - val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) + private val containerSparkHome = "/opt/spark" + private val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip - def afterEach() { + private def afterEach() { if (sc != null) { sc.stop() sc = null @@ -179,7 +180,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } } - def test(name: String)(fn: => Unit) { + private def test(name: String)(fn: => Unit) { try { fn numPassed += 1 @@ -197,19 +198,19 @@ private[spark] object FaultToleranceTest extends App with Logging { afterEach() } - def addMasters(num: Int) { + private def addMasters(num: Int) { logInfo(s">>>>> ADD MASTERS $num <<<<<") (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) } } - def addWorkers(num: Int) { + private def addWorkers(num: Int) { logInfo(s">>>>> ADD WORKERS $num <<<<<") val masterUrls = getMasterUrls(masters) (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } } /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ - def createClient() = { + private def createClient() = { logInfo(">>>>> CREATE CLIENT <<<<<") if (sc != null) { sc.stop() } // Counter-hack: Because of a hack in SparkEnv#create() that changes this @@ -218,17 +219,17 @@ private[spark] object FaultToleranceTest extends App with Logging { sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) } - def getMasterUrls(masters: Seq[TestMasterInfo]): String = { + private def getMasterUrls(masters: Seq[TestMasterInfo]): String = { "spark://" + masters.map(master => master.ip + ":7077").mkString(",") } - def getLeader: TestMasterInfo = { + private def getLeader: TestMasterInfo = { val leaders = masters.filter(_.state == RecoveryState.ALIVE) assertTrue(leaders.size == 1) leaders(0) } - def killLeader(): Unit = { + private def killLeader(): Unit = { logInfo(">>>>> KILL LEADER <<<<<") masters.foreach(_.readState()) val leader = getLeader @@ -236,9 +237,9 @@ private[spark] object FaultToleranceTest extends App with Logging { leader.kill() } - def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) + private def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) - def terminateCluster() { + private def terminateCluster() { logInfo(">>>>> TERMINATE CLUSTER <<<<<") masters.foreach(_.kill()) workers.foreach(_.kill()) @@ -247,7 +248,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } /** This includes Client retry logic, so it may take a while if the cluster is recovering. */ - def assertUsable() = { + private def assertUsable() = { val f = future { try { val res = sc.parallelize(0 until 10).collect() @@ -269,7 +270,7 @@ private[spark] object FaultToleranceTest extends App with Logging { * Asserts that the cluster is usable and that the expected masters and workers * are all alive in a proper configuration (e.g., only one leader). */ - def assertValidClusterState() = { + private def assertValidClusterState() = { logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<") assertUsable() var numAlive = 0 @@ -325,7 +326,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } } - def assertTrue(bool: Boolean, message: String = "") { + private def assertTrue(bool: Boolean, message: String = "") { if (!bool) { throw new IllegalStateException("Assertion failed: " + message) } @@ -335,7 +336,7 @@ private[spark] object FaultToleranceTest extends App with Logging { numFailed)) } -private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) +private class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { implicit val formats = org.json4s.DefaultFormats @@ -377,7 +378,7 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val format(ip, dockerId.id, logFile.getAbsolutePath, state) } -private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) +private class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { implicit val formats = org.json4s.DefaultFormats @@ -390,7 +391,7 @@ private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath) } -private[spark] object SparkDocker { +private object SparkDocker { def startMaster(mountDir: String): TestMasterInfo = { val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir) val (ip, id, outFile) = startNode(cmd) @@ -405,8 +406,7 @@ private[spark] object SparkDocker { private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { val ipPromise = promise[String]() - val outFile = File.createTempFile("fault-tolerance-test", "") - outFile.deleteOnExit() + val outFile = File.createTempFile("fault-tolerance-test", "", Utils.createTempDir()) val outStream: FileWriter = new FileWriter(outFile) def findIpAndLog(line: String): Unit = { if (line.startsWith("CONTAINER_IP=")) { @@ -425,11 +425,11 @@ private[spark] object SparkDocker { } } -private[spark] class DockerId(val id: String) { - override def toString = id +private class DockerId(val id: String) { + override def toString: String = id } -private[spark] object Docker extends Logging { +private object Docker extends Logging { def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = { val mountCmd = if (mountDir != "") { " -v " + mountDir } else "" diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 696f32a6f573..2954f932b4f4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,14 +17,15 @@ package org.apache.spark.deploy +import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.ExecutorRunner -private[spark] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { +private[deploy] object JsonProtocol { + def writeWorkerInfo(obj: WorkerInfo): JObject = { ("id" -> obj.id) ~ ("host" -> obj.host) ~ ("port" -> obj.port) ~ @@ -39,34 +40,34 @@ private[spark] object JsonProtocol { ("lastheartbeat" -> obj.lastHeartbeat) } - def writeApplicationInfo(obj: ApplicationInfo) = { + def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ ("duration" -> obj.duration) } - def writeApplicationDescription(obj: ApplicationDescription) = { + def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ + ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } - def writeExecutorRunner(obj: ExecutorRunner) = { + def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ ("appid" -> obj.appId) ~ ("appdesc" -> writeApplicationDescription(obj.appDesc)) } - def writeDriverInfo(obj: DriverInfo) = { + def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ @@ -74,7 +75,7 @@ private[spark] object JsonProtocol { ("memory" -> obj.desc.mem) } - def writeMasterState(obj: MasterStateResponse) = { + def writeMasterState(obj: MasterStateResponse): JObject = { ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ ("cores" -> obj.workers.map(_.cores).sum) ~ @@ -87,7 +88,7 @@ private[spark] object JsonProtocol { ("status" -> obj.status.toString) } - def writeWorkerState(obj: WorkerStateResponse) = { + def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ ("masterwebuiurl" -> obj.masterWebUiUrl) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c9571..f0e77c2ba982 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -33,7 +33,11 @@ import org.apache.spark.util.Utils * fault recovery without spinning up a lot of processes. */ private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) +class LocalSparkCluster( + numWorkers: Int, + coresPerWorker: Int, + memoryPerWorker: Int, + conf: SparkConf) extends Logging { private val localHostname = Utils.localHostName() @@ -43,17 +47,19 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + // Disable REST server on Master in this mode unless otherwise specified + val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false") + /* Start the Master */ - val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masters, null, Some(workerNum)) + memoryPerWorker, masters, null, Some(workerNum), _conf) workerActorSystems += workerSystem } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 039c8719e286..53e18c4bcec2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} /** - * A main class used by spark-submit to launch Python applications. It executes python as a + * A main class used to launch Python applications. It executes python as a * subprocess and then has it connect back to the JVM to access system properties, etc. */ object PythonRunner { diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 000000000000..e99779f29978 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1..cfaebf9ea505 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,14 +21,13 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -52,18 +51,13 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { - val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) - if (user != SparkContext.SPARK_UNKNOWN_USER) { - logDebug("running as user: " + user) - val ugi = UserGroupInformation.createRemoteUser(user) - transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) - } else { - logDebug("running as SPARK_UNKNOWN_USER") - func() - } + val user = Utils.getCurrentUserName() + logDebug("running as user: " + user) + val ugi = UserGroupInformation.createRemoteUser(user) + transferCredentials(UserGroupInformation.getCurrentUser(), ugi) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -133,16 +127,15 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() Some(() => f() - baselineBytesRead) } catch { - case e: NoSuchMethodException => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) None } @@ -156,26 +149,23 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() Some(() => f() - baselineBytesWritten) } catch { - case e: NoSuchMethodException => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) None } } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } @@ -195,6 +185,52 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + def recurse(path: Path): Array[FileStatus] = { + val (directories, leaves) = fs.listStatus(path).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + } + + val baseStatus = fs.getFileStatus(basePath) + if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored + + /** + * Substitute variables by looking them up in Hadoop configs. Only variables that match the + * ${hadoopconf- .. } pattern are substituted. + */ + def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { + text match { + case HADOOP_CONF_PATTERN(matched) => { + logDebug(text + " matched " + HADOOP_CONF_PATTERN) + val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } + val eval = Option[String](hadoopConf.get(key)) + .map { value => + logDebug("Substituted " + matched + " with " + value) + text.replace(matched, value) + } + if (eval.isEmpty) { + // The variable was not found in Hadoop configs, so return text as is. + text + } else { + // Continue to substitute more variables. + substituteHadoopVariables(eval.get, hadoopConf) + } + } + case _ => { + logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) + text + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 050ba91eb2bc..296a0764b8ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,13 +18,37 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.{Modifier, InvocationTargetException} +import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import org.apache.spark.executor.ExecutorURLClassLoader -import org.apache.spark.util.Utils +import org.apache.hadoop.fs.Path +import org.apache.hadoop.security.UserGroupInformation +import org.apache.ivy.Ivy +import org.apache.ivy.core.LogOptions +import org.apache.ivy.core.module.descriptor._ +import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId} +import org.apache.ivy.core.report.ResolveReport +import org.apache.ivy.core.resolve.ResolveOptions +import org.apache.ivy.core.retrieve.RetrieveOptions +import org.apache.ivy.core.settings.IvySettings +import org.apache.ivy.plugins.matcher.GlobPatternMatcher +import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.deploy.rest._ +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + +/** + * Whether to submit, kill, or request the status of an application. + * The latter two operations are currently supported only for standalone cluster mode. + */ +private[deploy] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value + val SUBMIT, KILL, REQUEST_STATUS = Value +} /** * Main gateway of launching a Spark application. @@ -53,39 +77,132 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var exitFn: () => Unit = () => System.exit(1) private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String) = { + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") exitFn() } + private[spark] def printVersionAndExit(): Unit = { + printStream.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + printStream.println("Type --help for more information.") + exitFn() + } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { printStream.println(appArgs) } - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) + appArgs.action match { + case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + } + } + + /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + private def kill(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .killSubmission(args.master, args.submissionToKill) } /** - * @return a tuple containing - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a list of system properties and env vars, and - * (4) the main class for the child + * Request the status of an existing submission using the REST protocol. + * Standalone cluster mode only. */ - private[spark] def createLaunchEnv(args: SparkSubmitArguments) - : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { + private def requestStatus(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient() + .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) + } - // Values to return + /** + * Submit the application using the provided parameters. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. + */ + private def submit(args: SparkSubmitArguments): Unit = { + val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + + def doRunMain(): Unit = { + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + exitFn() + } else { + throw e + } + } + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + } + } + + // In standalone cluster mode, there are two submission gateways: + // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper + // (2) The new REST-based gateway introduced in Spark 1.3 + // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over + // to use the legacy gateway if the master endpoint turns out to be not a REST server. + if (args.isStandaloneCluster && args.useRest) { + try { + printStream.println("Running Spark using the REST application submission protocol.") + doRunMain() + } catch { + // Fail over to use the legacy submission gateway + case e: SubmitRestConnectionException => + printWarning(s"Master endpoint ${args.master} was not a REST server. " + + "Falling back to legacy submission gateway instead.") + args.useRest = false + submit(args) + } + // In all other modes, just run the main class as prepared + } else { + doRunMain() + } + } + + /** + * Prepare the environment for submitting an application. + * This returns a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a map of system properties, and + * (4) the main class for the child + * Exposed for testing. + */ + private[deploy] def prepareSubmitEnvironment(args: SparkSubmitArguments) + : (Seq[String], Seq[String], Map[String, String], String) = { + // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -134,24 +251,70 @@ object SparkSubmit { } } + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + + // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files + // too for packages that include Python code + val resolvedMavenCoordinates = + SparkSubmitUtils.resolveMavenCoordinates( + args.packages, Option(args.repositories), Option(args.ivyRepoPath)) + if (!resolvedMavenCoordinates.trim.isEmpty) { + if (args.jars == null || args.jars.trim.isEmpty) { + args.jars = resolvedMavenCoordinates + } else { + args.jars += s",$resolvedMavenCoordinates" + } + if (args.isPython) { + if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { + args.pyFiles = resolvedMavenCoordinates + } else { + args.pyFiles += s",$resolvedMavenCoordinates" + } + } + } + + // Require all python files to be local, so we can add them to the PYTHONPATH + // In YARN cluster mode, python files are distributed as regular files, which can be non-local + if (args.isPython && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local python files are supported: $args.primaryResource") + } + val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") + if (nonLocalPyFiles.nonEmpty) { + printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles") + } + } + + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") - case (_, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") + case (STANDALONE, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + case (_, CLUSTER) if isThriftServer(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } // If we're running a python app, set the main class to our specific python runner - if (args.isPython) { + if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { - args.mainClass = "py4j.GatewayServer" - args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") + args.mainClass = "org.apache.spark.api.python.PythonGatewayServer" } else { // If a python file is provided, add it to the child arguments and list of files to deploy. // Usage: PythonAppRunner
[app arguments] @@ -165,6 +328,34 @@ object SparkSubmit { } } + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + // Special flag to avoid deprecation warnings at the client sysProps("SPARK_SUBMIT") = "true" @@ -176,6 +367,7 @@ object SparkSubmit { OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -186,9 +378,13 @@ object SparkSubmit { sysProp = "spark.driver.extraLibraryPath"), // Standalone cluster only + // Do not set CL arguments here because there are multiple possibilities for the main class OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, + sysProp = "spark.driver.supervise"), // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), @@ -210,6 +406,8 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options + OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, @@ -229,7 +427,6 @@ object SparkSubmit { if (args.childArgs != null) { childArgs ++= args.childArgs } } - // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && @@ -242,9 +439,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -252,14 +448,21 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } - // In standalone-cluster mode, use Client as a wrapper around the user class - if (clusterManager == STANDALONE && deployMode == CLUSTER) { - childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { - childArgs += "--supervise" + // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). + // All Spark parameters are expected to be passed to the client through system properties. + if (args.isStandaloneCluster) { + if (args.useRest) { + childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childArgs += (args.primaryResource, args.mainClass) + } else { + // In legacy standalone cluster mode, use Client as a wrapper around the user class + childMainClass = "org.apache.spark.deploy.Client" + if (args.supervise) { childArgs += "--supervise" } + Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } + Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) } - childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -268,10 +471,26 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" - if (args.primaryResource != SPARK_INTERNAL) { - childArgs += ("--jar", args.primaryResource) + if (args.isPython) { + val mainPyFile = new Path(args.primaryResource).getName + childArgs += ("--primary-py-file", mainPyFile) + if (args.pyFiles != null) { + // These files will be distributed to each machine's working directory, so strip the + // path prefix + val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") + childArgs += ("--py-files", pyFilesNames) + } + childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") + } else { + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } + childArgs += ("--class", args.mainClass) } - childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } @@ -284,7 +503,7 @@ object SparkSubmit { // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= ("spark.driver.host") + sysProps -= "spark.driver.host" } // Resolve paths in certain spark properties @@ -313,12 +532,18 @@ object SparkSubmit { (childArgs, childClasspath, sysProps, childMainClass) } - private def launch( - childArgs: ArrayBuffer[String], - childClasspath: ArrayBuffer[String], + /** + * Run the main method of the child class using the provided launch environment. + * + * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. + */ + private def runMain( + childArgs: Seq[String], + childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, - verbose: Boolean = false) { + verbose: Boolean): Unit = { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") @@ -327,8 +552,14 @@ object SparkSubmit { printStream.println("\n") } - val loader = new ExecutorURLClassLoader(new Array[URL](0), - Thread.currentThread.getContextClassLoader) + val loader = + if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } else { + new MutableURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + } Thread.currentThread.setContextClassLoader(loader) for (jar <- childClasspath) { @@ -347,8 +578,8 @@ object SparkSubmit { case e: ClassNotFoundException => e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { - println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive and -Phive-thriftserver.") + printStream.println(s"Failed to load main class $childMainClass.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } @@ -362,17 +593,25 @@ object SparkSubmit { if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") } + + def findCause(t: Throwable): Throwable = t match { + case e: UndeclaredThrowableException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: InvocationTargetException => + if (e.getCause() != null) findCause(e.getCause()) else e + case e: Throwable => + e + } + try { mainMethod.invoke(null, childArgs.toArray) } catch { - case e: InvocationTargetException => e.getCause match { - case cause: Throwable => throw cause - case null => throw e - } + case t: Throwable => + throw findCause(t) } } - private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) { + private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { val uri = Utils.resolveURI(localJar) uri.getScheme match { case "file" | "local" => @@ -390,40 +629,54 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[spark] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** * Return whether the given main class represents a sql shell. */ - private[spark] def isSqlShell(mainClass: String): Boolean = { + private def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } + /** + * Return whether the given main class represents a thrift server. + */ + private def isThriftServer(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + } + /** * Return whether the given primary resource requires running python. */ - private[spark] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL } - private[spark] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL + } + + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. */ - private[spark] def mergeFileLists(lists: String*): String = { + private def mergeFileLists(lists: String*): String = { val merged = lists.filter(_ != null) .flatMap(_.split(",")) .mkString(",") @@ -431,11 +684,234 @@ object SparkSubmit { } } +/** Provides utility functions to be used inside SparkSubmit. */ +private[deploy] object SparkSubmitUtils { + + // Exposed for testing + var printStream = SparkSubmit.printStream + + /** + * Represents a Maven Coordinate + * @param groupId the groupId of the coordinate + * @param artifactId the artifactId of the coordinate + * @param version the version of the coordinate + */ + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + +/** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ + def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { + coordinates.split(",").map { p => + val splits = p.replace("/", ":").split(":") + require(splits.length == 3, s"Provided Maven Coordinates must be in the form " + + s"'groupId:artifactId:version'. The coordinate provided is: $p") + require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " + + s"be whitespace. The groupId provided is: ${splits(0)}") + require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " + + s"be whitespace. The artifactId provided is: ${splits(1)}") + require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " + + s"be whitespace. The version provided is: ${splits(2)}") + new MavenCoordinate(splits(0), splits(1), splits(2)) + } + } + + /** + * Extracts maven coordinates from a comma-delimited string + * @param remoteRepos Comma-delimited string of remote repositories + * @return A ChainResolver used by Ivy to search for and resolve dependencies. + */ + def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + // We need a chain resolver if we want to check multiple repositories + val cr = new ChainResolver + cr.setName("list") + + // the biblio resolver resolves POM declared dependencies + val br: IBiblioResolver = new IBiblioResolver + br.setM2compatible(true) + br.setUsepoms(true) + br.setName("central") + cr.add(br) + + val sp: IBiblioResolver = new IBiblioResolver + sp.setM2compatible(true) + sp.setUsepoms(true) + sp.setRoot("http://dl.bintray.com/spark-packages/maven") + sp.setName("spark-packages") + cr.add(sp) + + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + cr + } + + /** + * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath + * (will append to jars in SparkSubmit). The name of the jar is given + * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * @param artifacts Sequence of dependencies that were resolved and retrieved + * @param cacheDirectory directory where jars are cached + * @return a comma-delimited list of paths for the dependencies + */ + def resolveDependencyPaths( + artifacts: Array[AnyRef], + cacheDirectory: File): String = { + artifacts.map { artifactInfo => + val artifactString = artifactInfo.toString + val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + cacheDirectory.getAbsolutePath + File.separator + + jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + }.mkString(",") + } + + /** Adds the given maven coordinates to Ivy's module descriptor. */ + def addDependenciesToIvy( + md: DefaultModuleDescriptor, + artifacts: Seq[MavenCoordinate], + ivyConfName: String): Unit = { + artifacts.foreach { mvn => + val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) + val dd = new DefaultDependencyDescriptor(ri, false, false) + dd.addDependencyConfiguration(ivyConfName, ivyConfName) + printStream.println(s"${dd.getDependencyId} added as a dependency") + md.addDependency(dd) + } + } + + /** Add exclusion rules for dependencies already included in the spark-assembly */ + def addExclusionRules( + ivySettings: IvySettings, + ivyConfName: String, + md: DefaultModuleDescriptor): Unit = { + // Add scala exclusion rule + val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") + val scalaDependencyExcludeRule = + new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) + scalaDependencyExcludeRule.addConfiguration(ivyConfName) + md.addExcludeRule(scalaDependencyExcludeRule) + + // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and + // other spark-streaming utility components. Underscore is there to differentiate between + // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x + val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") + + components.foreach { comp => + val sparkArtifacts = + new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*") + val sparkDependencyExcludeRule = + new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) + sparkDependencyExcludeRule.addConfiguration(ivyConfName) + + md.addExcludeRule(sparkDependencyExcludeRule) + } + } + + /** A nice function to use in tests as well. Values are dummy strings. */ + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + + /** + * Resolves any dependencies that were supplied through maven coordinates + * @param coordinates Comma-delimited string of maven coordinates + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return The comma-delimited path to the jars of the given maven artifacts including their + * transitive dependencies + */ + def resolveMavenCoordinates( + coordinates: String, + remoteRepos: Option[String], + ivyPath: Option[String], + isTest: Boolean = false): String = { + if (coordinates == null || coordinates.trim.isEmpty) { + "" + } else { + val sysOut = System.out + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } + + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + System.setOut(sysOut) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } + } +} + /** * Provides an indirection layer for passing arguments as system properties or flags to * the user's driver program or to downstream launcher tools. */ -private[spark] case class OptionAssigner( +private case class OptionAssigner( value: String, clusterManager: Int, deployMode: Int, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 81ec08cb6d50..c896842943f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -18,17 +18,22 @@ package org.apache.spark.deploy import java.net.URI +import java.util.{List => JList} import java.util.jar.JarFile +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.util.Utils /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. */ -private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) { +private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) + extends SparkSubmitArgumentsParser { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -39,8 +44,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraClassPath: String = null var driverExtraLibraryPath: String = null var driverExtraJavaOptions: String = null - var driverCores: String = null - var supervise: Boolean = false var queue: String = null var numExecutors: String = null var files: String = null @@ -50,10 +53,23 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var name: String = null var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var jars: String = null + var packages: String = null + var repositories: String = null + var ivyRepoPath: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false + var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() + var proxyUser: String = null + + // Standalone cluster mode only + var supervise: Boolean = false + var driverCores: String = null + var submissionToKill: String = null + var submissionToRequestStatusFor: String = null + var useRest: Boolean = true // used internally /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { @@ -61,25 +77,28 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => Utils.getPropertiesFromFile(filename).foreach { case (k, v) => - if (k.startsWith("spark.")) { - defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") - } else { - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") - } + defaultProperties(k) = v + if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } } defaultProperties } // Set parameters from command line arguments - parseOpts(args.toList) + try { + parse(args.toList) + } catch { + case e: IllegalArgumentException => + SparkSubmit.printErrorAndExit(e.getMessage()) + } // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() + // Remove keys that don't start with "spark." from `sparkProperties`. + ignoreNonSparkProperties() // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() - checkRequiredArguments() + validateArguments() /** * Merge values from the default properties file with those specified through --conf. @@ -96,6 +115,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + /** + * Remove keys that don't start with "spark." from `sparkProperties`. + */ + private def ignoreNonSparkProperties(): Unit = { + sparkProperties.foreach { case (k, v) => + if (!k.startsWith("spark.")) { + sparkProperties -= k + SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + } + } + } + /** * Load arguments from environment variables, Spark properties etc. */ @@ -104,6 +135,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.master")) .orElse(env.get("MASTER")) .orNull + driverExtraClassPath = Option(driverExtraClassPath) + .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orNull + driverExtraJavaOptions = Option(driverExtraJavaOptions) + .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orNull + driverExtraLibraryPath = Option(driverExtraLibraryPath) + .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orNull driverMemory = Option(driverMemory) .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) @@ -123,12 +163,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull + ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -162,35 +203,34 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (name == null && primaryResource != null) { name = Utils.stripDirectory(primaryResource) } + + // Action should be SUBMIT unless otherwise specified + action = Option(action).getOrElse(SUBMIT) } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments(): Unit = { + private def validateArguments(): Unit = { + action match { + case SUBMIT => validateSubmitArguments() + case KILL => validateKillArguments() + case REQUEST_STATUS => validateStatusRequestArguments() + } + } + + private def validateSubmitArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") } - // Require all python files to be local, so we can add them to the PYTHONPATH - if (isPython) { - if (Utils.nonLocalPaths(primaryResource).nonEmpty) { - SparkSubmit.printErrorAndExit(s"Only local python files are supported: $primaryResource") - } - val nonLocalPyFiles = Utils.nonLocalPaths(pyFiles).mkString(",") - if (nonLocalPyFiles.nonEmpty) { - SparkSubmit.printErrorAndExit( - s"Only local additional python files are supported: $nonLocalPyFiles") - } - } - if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { @@ -200,7 +240,30 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + private def validateKillArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + } + if (submissionToKill == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + } + } + + private def validateStatusRequestArguments(): Unit = { + if (!master.startsWith("spark://")) { + SparkSubmit.printErrorAndExit( + "Requesting submission statuses is only supported in standalone mode!") + } + if (submissionToRequestStatusFor == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + } + } + + def isStandaloneCluster: Boolean = { + master.startsWith("spark://") && deployMode == "cluster" + } + + override def toString: String = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -224,6 +287,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | name $name | childArgs [${childArgs.mkString(" ")}] | jars $jars + | packages $packages + | repositories $repositories | verbose $verbose | |Spark properties used, including those specified through @@ -232,136 +297,140 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St """.stripMargin } - /** - * Fill in values by parsing user options. - * NOTE: Any changes here must be reflected in YarnClientSchedulerBackend. - */ - private def parseOpts(opts: Seq[String]): Unit = { - val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r - - // Delineates parsing of Spark options from parsing of user options. - parse(opts) - - /** - * NOTE: If you add or remove spark-submit options, - * modify NOT ONLY this file but also utils.sh - */ - def parse(opts: Seq[String]): Unit = opts match { - case ("--name") :: value :: tail => + /** Fill in values by parsing user options. */ + override protected def handle(opt: String, value: String): Boolean = { + opt match { + case NAME => name = value - parse(tail) - case ("--master") :: value :: tail => + case MASTER => master = value - parse(tail) - case ("--class") :: value :: tail => + case CLASS => mainClass = value - parse(tail) - case ("--deploy-mode") :: value :: tail => + case DEPLOY_MODE => if (value != "client" && value != "cluster") { SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value - parse(tail) - case ("--num-executors") :: value :: tail => + case NUM_EXECUTORS => numExecutors = value - parse(tail) - case ("--total-executor-cores") :: value :: tail => + case TOTAL_EXECUTOR_CORES => totalExecutorCores = value - parse(tail) - case ("--executor-cores") :: value :: tail => + case EXECUTOR_CORES => executorCores = value - parse(tail) - case ("--executor-memory") :: value :: tail => + case EXECUTOR_MEMORY => executorMemory = value - parse(tail) - case ("--driver-memory") :: value :: tail => + case DRIVER_MEMORY => driverMemory = value - parse(tail) - case ("--driver-cores") :: value :: tail => + case DRIVER_CORES => driverCores = value - parse(tail) - case ("--driver-class-path") :: value :: tail => + case DRIVER_CLASS_PATH => driverExtraClassPath = value - parse(tail) - case ("--driver-java-options") :: value :: tail => + case DRIVER_JAVA_OPTIONS => driverExtraJavaOptions = value - parse(tail) - case ("--driver-library-path") :: value :: tail => + case DRIVER_LIBRARY_PATH => driverExtraLibraryPath = value - parse(tail) - case ("--properties-file") :: value :: tail => + case PROPERTIES_FILE => propertiesFile = value - parse(tail) - case ("--supervise") :: tail => + case KILL_SUBMISSION => + submissionToKill = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + } + action = KILL + + case STATUS => + submissionToRequestStatusFor = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + } + action = REQUEST_STATUS + + case SUPERVISE => supervise = true - parse(tail) - case ("--queue") :: value :: tail => + case QUEUE => queue = value - parse(tail) - case ("--files") :: value :: tail => + case FILES => files = Utils.resolveURIs(value) - parse(tail) - case ("--py-files") :: value :: tail => + case PY_FILES => pyFiles = Utils.resolveURIs(value) - parse(tail) - case ("--archives") :: value :: tail => + case ARCHIVES => archives = Utils.resolveURIs(value) - parse(tail) - case ("--jars") :: value :: tail => + case JARS => jars = Utils.resolveURIs(value) - parse(tail) - case ("--conf" | "-c") :: value :: tail => + case PACKAGES => + packages = value + + case REPOSITORIES => + repositories = value + + case CONF => value.split("=", 2).toSeq match { case Seq(k, v) => sparkProperties(k) = v case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") } - parse(tail) - case ("--help" | "-h") :: tail => + case PROXY_USER => + proxyUser = value + + case HELP => printUsageAndExit(0) - case ("--verbose" | "-v") :: tail => + case VERBOSE => verbose = true - parse(tail) - - case EQ_SEPARATED_OPT(opt, value) :: tail => - parse(opt :: value :: tail) - case value :: tail if value.startsWith("-") => - SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.") + case VERSION => + SparkSubmit.printVersionAndExit() - case value :: tail => - primaryResource = - if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) { - Utils.resolveURI(value).toString - } else { - value - } - isPython = SparkSubmit.isPython(value) - childArgs ++= tail + case _ => + throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + } + true + } - case Nil => + /** + * Handle unrecognized command line options. + * + * The first unrecognized option is treated as the "primary resource". Everything else is + * treated as application arguments. + */ + override protected def handleUnknown(opt: String): Boolean = { + if (opt.startsWith("-")) { + SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") } + + primaryResource = + if (!SparkSubmit.isShell(opt) && !SparkSubmit.isInternal(opt)) { + Utils.resolveURI(opt).toString + } else { + opt + } + isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) + false + } + + override protected def handleExtraArgs(extra: JList[String]): Unit = { + childArgs ++= extra } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { @@ -370,7 +439,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St outStream.println("Unknown/unsupported param " + unknownParam) } outStream.println( - """Usage: spark-submit [options] [app options] + """Usage: spark-submit [options] [app arguments] + |Usage: spark-submit --kill [submission ID] --master [spark://...] + |Usage: spark-submit --status [submission ID] --master [spark://...] + | |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -380,6 +452,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --name NAME A name of your application. | --jars JARS Comma-separated list of local jars to include on the driver | and executor classpaths. + | --packages Comma-separated list of maven coordinates of jars to include + | on the driver and executor classpaths. Will search the local + | maven repo, then maven central and any additional remote + | repositories given by --repositories. The format for the + | coordinates should be groupId:artifactId:version. + | --repositories Comma-separated list of additional remote repositories to + | search for the maven coordinates given with --packages. | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working @@ -398,20 +477,28 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | + | --proxy-user NAME User to impersonate when submitting the application. + | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output + | --version, Print the version of current Spark | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). | --supervise If given, restarts the driver on failure. + | --kill SUBMISSION_ID If given, kills the driver specified. + | --status SUBMISSION_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | + | Spark standalone and YARN only: + | --executor-cores NUM Number of cores per executor. (Default: 1 in YARN mode, + | or all available cores on the worker in standalone mode) + | | YARN-only: | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). - | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala deleted file mode 100644 index 2eab9981845e..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy - -import java.io.File - -import scala.collection.JavaConversions._ - -import org.apache.spark.util.{RedirectThread, Utils} - -/** - * Launch an application through Spark submit in client mode with the appropriate classpath, - * library paths, java options and memory. These properties of the JVM must be set before the - * driver JVM is launched. The sole purpose of this class is to avoid handling the complexity - * of parsing the properties file for such relevant configs in Bash. - * - * Usage: org.apache.spark.deploy.SparkSubmitDriverBootstrapper - */ -private[spark] object SparkSubmitDriverBootstrapper { - - // Note: This class depends on the behavior of `bin/spark-class` and `bin/spark-submit`. - // Any changes made there must be reflected in this file. - - def main(args: Array[String]): Unit = { - - // This should be called only from `bin/spark-class` - if (!sys.env.contains("SPARK_CLASS")) { - System.err.println("SparkSubmitDriverBootstrapper must be called from `bin/spark-class`!") - System.exit(1) - } - - val submitArgs = args - val runner = sys.env("RUNNER") - val classpath = sys.env("CLASSPATH") - val javaOpts = sys.env("JAVA_OPTS") - val defaultDriverMemory = sys.env("OUR_JAVA_MEM") - - // Spark submit specific environment variables - val deployMode = sys.env("SPARK_SUBMIT_DEPLOY_MODE") - val propertiesFile = sys.env("SPARK_SUBMIT_PROPERTIES_FILE") - val bootstrapDriver = sys.env("SPARK_SUBMIT_BOOTSTRAP_DRIVER") - val submitDriverMemory = sys.env.get("SPARK_SUBMIT_DRIVER_MEMORY") - val submitLibraryPath = sys.env.get("SPARK_SUBMIT_LIBRARY_PATH") - val submitClasspath = sys.env.get("SPARK_SUBMIT_CLASSPATH") - val submitJavaOpts = sys.env.get("SPARK_SUBMIT_OPTS") - - assume(runner != null, "RUNNER must be set") - assume(classpath != null, "CLASSPATH must be set") - assume(javaOpts != null, "JAVA_OPTS must be set") - assume(defaultDriverMemory != null, "OUR_JAVA_MEM must be set") - assume(deployMode == "client", "SPARK_SUBMIT_DEPLOY_MODE must be \"client\"!") - assume(propertiesFile != null, "SPARK_SUBMIT_PROPERTIES_FILE must be set") - assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") - - // Parse the properties file for the equivalent spark.driver.* configs - val properties = Utils.getPropertiesFromFile(propertiesFile) - val confDriverMemory = properties.get("spark.driver.memory") - val confLibraryPath = properties.get("spark.driver.extraLibraryPath") - val confClasspath = properties.get("spark.driver.extraClassPath") - val confJavaOpts = properties.get("spark.driver.extraJavaOptions") - - // Favor Spark submit arguments over the equivalent configs in the properties file. - // Note that we do not actually use the Spark submit values for library path, classpath, - // and Java opts here, because we have already captured them in Bash. - - val newDriverMemory = submitDriverMemory - .orElse(confDriverMemory) - .getOrElse(defaultDriverMemory) - - val newClasspath = - if (submitClasspath.isDefined) { - classpath - } else { - classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") - } - - val newJavaOpts = - if (submitJavaOpts.isDefined) { - // SPARK_SUBMIT_OPTS is already captured in JAVA_OPTS - javaOpts - } else { - javaOpts + confJavaOpts.map(" " + _).getOrElse("") - } - - val filteredJavaOpts = Utils.splitCommandString(newJavaOpts) - .filterNot(_.startsWith("-Xms")) - .filterNot(_.startsWith("-Xmx")) - - // Build up command - val command: Seq[String] = - Seq(runner) ++ - Seq("-cp", newClasspath) ++ - filteredJavaOpts ++ - Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ - Seq("org.apache.spark.deploy.SparkSubmit") ++ - submitArgs - - // Print the launch command. This follows closely the format used in `bin/spark-class`. - if (sys.env.contains("SPARK_PRINT_LAUNCH_COMMAND")) { - System.err.print("Spark Command: ") - System.err.println(command.mkString(" ")) - System.err.println("========================================\n") - } - - // Start the driver JVM - val filteredCommand = command.filter(_.nonEmpty) - val builder = new ProcessBuilder(filteredCommand) - val env = builder.environment() - - if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) { - val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName) - env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator"))) - } - - val process = builder.start() - - // If we kill an app while it's running, its sub-process should be killed too. - Runtime.getRuntime().addShutdownHook(new Thread() { - override def run() = { - if (process != null) { - process.destroy() - process.waitFor() - } - } - }) - - // Redirect stdout and stderr from the child JVM - val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") - val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") - stdoutThread.start() - stderrThread.start() - - // Redirect stdin to child JVM only if we're not running Windows. This is because the - // subprocess there already reads directly from our stdin, so we should avoid spawning a - // thread that contends with the subprocess in reading from System.in. - val isWindows = Utils.isWindows - val isSubprocess = sys.env.contains("IS_SUBPROCESS") - if (!isWindows) { - val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin", - propagateEof = true) - stdinThread.start() - // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on - // broken pipe, signaling that the parent process has exited. This is the case if the - // application is launched directly from python, as in the PySpark shell. In Windows, - // the termination logic is handled in java_gateway.py - if (isSubprocess) { - stdinThread.join() - process.destroy() - } - } - val returnCode = process.waitFor() - sys.exit(returnCode) - } - -} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 39a7b0319b6a..43c8a934c311 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -47,18 +47,18 @@ private[spark] class AppClient( conf: SparkConf) extends Logging { - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) + private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - val REGISTRATION_TIMEOUT = 20.seconds - val REGISTRATION_RETRIES = 3 + private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_RETRIES = 3 - var masterAddress: Address = null - var actor: ActorRef = null - var appId: String = null - var registered = false - var activeMasterUrl: String = null + private var masterAddress: Address = null + private var actor: ActorRef = null + private var appId: String = null + private var registered = false + private var activeMasterUrl: String = null - class ClientActor extends Actor with ActorLogReceive with Logging { + private class ClientActor extends Actor with ActorLogReceive with Logging { var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times @@ -107,15 +107,16 @@ private[spark] class AppClient( def changeMaster(url: String) { // activeMasterUrl is a valid Spark url since we receive it from master. activeMasterUrl = url - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + master = context.actorSelection( + Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) + masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) } private def isPossibleMaster(remoteUrl: Address) = { masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true @@ -156,6 +157,7 @@ private[spark] class AppClient( case StopAppClient => markDead("Application has been stopped.") + master ! UnregisterApplication(appId) sender ! true context.stop(self) } @@ -191,7 +193,7 @@ private[spark] class AppClient( def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) val future = actor.ask(StopAppClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 88a0862b96af..40835b955058 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} private[spark] object TestClient { - class TestListener extends AppClientListener with Logging { + private class TestListener extends AppClientListener with Logging { def connected(id: String) { logInfo("Connected to master, got app ID " + id) } @@ -46,7 +46,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 553bf3cb945a..ea6c85ee511d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import org.apache.spark.ui.SparkUI -private[spark] case class ApplicationHistoryInfo( +private[history] case class ApplicationHistoryInfo( id: String, name: String, startTime: Long, @@ -28,7 +28,7 @@ private[spark] case class ApplicationHistoryInfo( sparkUser: String, completed: Boolean = false) -private[spark] abstract class ApplicationHistoryProvider { +private[history] abstract class ApplicationHistoryProvider { /** * Returns a list of applications available for the history server to show. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b7..a94ebf6e5375 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,19 +17,23 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, InputStream} +import java.io.{IOException, BufferedInputStream, FileNotFoundException, InputStream} +import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import scala.collection.mutable +import scala.concurrent.duration.Duration -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.fs.permission.AccessControlException +import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import com.google.common.util.concurrent.MoreExecutors +import org.apache.hadoop.fs.permission.AccessControlException +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * A class that provides application history from event logs stored in the file system. @@ -44,8 +48,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val NOT_STARTED = "" // Interval between each check for event log updates - private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", - conf.getInt("spark.history.updateInterval", 10)) * 1000 + private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") + + // Interval between each cleaner checks for event logs to delete + private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d") private val logDir = conf.getOption("spark.history.fs.logDirectory") .map { d => Utils.resolveURI(d).toString } @@ -53,8 +59,11 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) - // A timestamp of when the disk was last accessed to check for log updates - private var lastLogCheckTimeMs = -1L + // Used by check event thread and clean log thread. + // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs + // and applications between check task and clean task. + private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat("spark-history-task-%d").setDaemon(true).build()) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that @@ -66,36 +75,38 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() + // List of applications to be deleted by event log cleaner. + private var appsToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + // Constants used to parse Spark 1.0.0 log directories. private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = "SPARK_VERSION_" - private[history] val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" + private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" /** - * A background thread that periodically checks for event log updates on disk. - * - * If a log check is invoked manually in the middle of a period, this thread re-adjusts the - * time at which it performs the next log check to maintain the same period as before. - * - * TODO: Add a mechanism to update manually. + * Return a runnable that performs the given operation on the event logs. + * This operation is expected to be executed periodically. */ - private val logCheckingThread = new Thread("LogCheckingThread") { - override def run() = Utils.logUncaughtExceptions { - while (true) { - val now = getMonotonicTimeMs() - if (now - lastLogCheckTimeMs > UPDATE_INTERVAL_MS) { - Thread.sleep(UPDATE_INTERVAL_MS) - } else { - // If the user has manually checked for logs recently, wait until - // UPDATE_INTERVAL_MS after the last check time - Thread.sleep(lastLogCheckTimeMs + UPDATE_INTERVAL_MS - now) - } - checkForLogs() + private def getRunner(operateFun: () => Unit): Runnable = { + new Runnable() { + override def run(): Unit = Utils.tryOrExit { + operateFun() } } } + /** + * An Executor to fetch and parse log files. + */ + private val replayExecutor: ExecutorService = { + if (!conf.contains("spark.testing")) { + ThreadUtils.newDaemonSingleThreadExecutor("log-replay-executor") + } else { + MoreExecutors.sameThreadExecutor() + } + } + initialize() private def initialize(): Unit = { @@ -104,7 +115,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } @@ -113,16 +124,19 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis "Logging directory specified is not a directory: %s".format(logDir)) } - checkForLogs() - // Disable the background thread during tests. if (!conf.contains("spark.testing")) { - logCheckingThread.setDaemon(true) - logCheckingThread.start() + // A task that periodically checks for event log updates on disk. + pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + + if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { + // A task that periodically cleans event logs on disk. + pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + } } } - override def getListing() = applications.values + override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values override def getAppUI(appId: String): Option[SparkUI] = { try { @@ -163,19 +177,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * applications that haven't been updated since last time the logs were checked. */ private[history] def checkForLogs(): Unit = { - lastLogCheckTimeMs = getMonotonicTimeMs() - logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) - try { - var newLastModifiedTime = lastModifiedTime val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - val logInfos = statusList + var newLastModifiedTime = lastModifiedTime + val logInfos: Seq[FileStatus] = statusList .filter { entry => try { - val modTime = getModificationTime(entry) - newLastModifiedTime = math.max(newLastModifiedTime, modTime) - modTime >= lastModifiedTime + getModificationTime(entry).map { time => + newLastModifiedTime = math.max(newLastModifiedTime, time) + time >= lastModifiedTime + }.getOrElse(false) } catch { case e: AccessControlException => // Do not use "logInfo" since these messages can get pretty noisy if printed on @@ -184,56 +196,136 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis false } } - .flatMap { entry => - try { - Some(replay(entry, new ReplayListenerBus())) - } catch { - case e: Exception => - logError(s"Failed to load application log data from $entry.", e) - None - } - } - .sortBy { info => (-info.endTime, -info.startTime) } + .flatMap { entry => Some(entry) } + .sortWith { case (entry1, entry2) => + val mod1 = getModificationTime(entry1).getOrElse(-1L) + val mod2 = getModificationTime(entry2).getOrElse(-1L) + mod1 >= mod2 + } + + logInfos.sliding(20, 20).foreach { batch => + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(batch) + }) + } lastModifiedTime = newLastModifiedTime + } catch { + case e: Exception => logError("Exception in checking for event log updates", e) + } + } - // When there are new logs, merge the new list with the existing one, maintaining - // the expected ordering (descending end time). Maintaining the order is important - // to avoid having to sort the list every time there is a request for the log list. - if (!logInfos.isEmpty) { - val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { - newApps += (info.id -> info) - } + /** + * Replay the log files in the list and merge the list of old applications with new ones + */ + private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { + val bus = new ReplayListenerBus() + val newApps = logs.flatMap { fileStatus => + try { + val res = replay(fileStatus, bus) + logInfo(s"Application log ${res.logPath} loaded successfully.") + Some(res) + } catch { + case e: Exception => + logError( + s"Exception encountered when attempting to load application log ${fileStatus.getPath}", + e) + None + } + }.toSeq.sortWith(compareAppInfo) + + // When there are new logs, merge the new list with the existing one, maintaining + // the expected ordering (descending end time). Maintaining the order is important + // to avoid having to sort the list every time there is a request for the log list. + if (newApps.nonEmpty) { + val mergedApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + def addIfAbsent(info: FsApplicationHistoryInfo): Unit = { + if (!mergedApps.contains(info.id) || + mergedApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { + mergedApps += (info.id -> info) } + } - val newIterator = logInfos.iterator.buffered - val oldIterator = applications.values.iterator.buffered - while (newIterator.hasNext && oldIterator.hasNext) { - if (newIterator.head.endTime > oldIterator.head.endTime) { - addIfAbsent(newIterator.next) - } else { - addIfAbsent(oldIterator.next) - } + val newIterator = newApps.iterator.buffered + val oldIterator = applications.values.iterator.buffered + while (newIterator.hasNext && oldIterator.hasNext) { + if (compareAppInfo(newIterator.head, oldIterator.head)) { + addIfAbsent(newIterator.next()) + } else { + addIfAbsent(oldIterator.next()) } - newIterator.foreach(addIfAbsent) - oldIterator.foreach(addIfAbsent) + } + newIterator.foreach(addIfAbsent) + oldIterator.foreach(addIfAbsent) - applications = newApps + applications = mergedApps + } + } + + /** + * Delete event logs from the log directory according to the clean policy defined by the user. + */ + private def cleanLogs(): Unit = { + try { + val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 + + val now = System.currentTimeMillis() + val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + + // Scan all logs from the log directory. + // Only completed applications older than the specified max age will be deleted. + applications.values.foreach { info => + if (now - info.lastUpdated <= maxAge || !info.completed) { + appsToRetain += (info.id -> info) + } else { + appsToClean += info + } } + + applications = appsToRetain + + val leftToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + appsToClean.foreach { info => + try { + val path = new Path(logDir, info.logPath) + if (fs.exists(path)) { + fs.delete(path, true) + } + } catch { + case e: AccessControlException => + logInfo(s"No permission to delete ${info.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning logs of ${info.logPath}", t) + leftToClean += info + } + } + + appsToClean = leftToClean } catch { - case e: Exception => logError("Exception in checking for event log updates", e) + case t: Exception => logError("Exception in cleaning logs", t) } } + /** + * Comparison function that defines the sort order for the application listing. + * + * @return Whether `i1` should precede `i2`. + */ + private def compareAppInfo( + i1: FsApplicationHistoryInfo, + i2: FsApplicationHistoryInfo): Boolean = { + if (i1.endTime != i2.endTime) i1.endTime >= i2.endTime else i1.startTime >= i2.startTime + } + /** * Replays the events in the specified log file and returns information about the associated * application. */ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = { val logPath = eventLog.getPath() - val (logInput, sparkVersion) = + logInfo(s"Replaying log path: $logPath") + val logInput = if (isLegacyLogDirectory(eventLog)) { openLegacyEventLog(logPath) } else { @@ -242,14 +334,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val appListener = new ApplicationEventListener bus.addListener(appListener) - bus.replay(logInput, sparkVersion) + bus.replay(logInput, logPath.toString) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), appListener.appName.getOrElse(NOT_STARTED), appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog), + getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), isApplicationCompleted(eventLog)) } finally { @@ -262,30 +354,24 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis * log file (along with other metadata files), which is the case for directories generated by * the code in previous releases. * - * @return 2-tuple of (input stream of the events, version of Spark which wrote the log) + * @return input stream that holds one JSON record per line. */ - private[history] def openLegacyEventLog(dir: Path): (InputStream, String) = { + private[history] def openLegacyEventLog(dir: Path): InputStream = { val children = fs.listStatus(dir) var eventLogPath: Path = null var codecName: Option[String] = None - var sparkVersion: String = null children.foreach { child => child.getPath().getName() match { case name if name.startsWith(LOG_PREFIX) => eventLogPath = child.getPath() - case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) => codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length())) - - case version if version.startsWith(SPARK_VERSION_PREFIX) => - sparkVersion = version.substring(SPARK_VERSION_PREFIX.length()) - case _ => } } - if (eventLogPath == null || sparkVersion == null) { + if (eventLogPath == null) { throw new IllegalArgumentException(s"$dir is not a Spark application log directory.") } @@ -297,7 +383,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } val in = new BufferedInputStream(fs.open(eventLogPath)) - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } /** @@ -308,17 +394,19 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() - private def getModificationTime(fsEntry: FileStatus): Long = { - if (fsEntry.isDir) { - fs.listStatus(fsEntry.getPath).map(_.getModificationTime()).max + /** + * Returns the modification time of the given event log. If the status points at an empty + * directory, `None` is returned, indicating that there isn't an event log at that location. + */ + private def getModificationTime(fsEntry: FileStatus): Option[Long] = { + if (isLegacyLogDirectory(fsEntry)) { + val statusList = fs.listStatus(fsEntry.getPath) + if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None } else { - fsEntry.getModificationTime() + Some(fsEntry.getModificationTime()) } } - /** Returns the system's mononotically increasing time. */ - private def getMonotonicTimeMs(): Long = System.nanoTime() / (1000 * 1000) - /** * Return true when the application has completed. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index e4e7bc221601..3781b4e8c12b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} -private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { +private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { private val pageSize = 20 private val plusOrMinus = 2 @@ -61,9 +61,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { // page, `...` will be displayed. if (allApps.size > 0) { val leftSideIndices = - rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _) + rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _, requestedIncomplete) val rightSideIndices = - rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount) + rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount, + requestedIncomplete)

Showing {actualFirst + 1}-{last + 1} of {allApps.size} @@ -89,6 +90,8 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

++ appTable + } else if (requestedIncomplete) { +

No incomplete applications found!

} else {

No completed applications found!

++

Did you specify the correct logging directory? @@ -122,8 +125,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Spark User", "Last Updated") - private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = { - range.filter(condition).map(nextPage => {nextPage} ) + private def rangeIndices(range: Seq[Int], condition: Int => Boolean, showIncomplete: Boolean): + Seq[Node] = { + range.filter(condition).map(nextPage => + {nextPage} ) } private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index fa9bfe5426b6..56bef57e5539 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -61,7 +61,7 @@ class HistoryServer( private val appCache = CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { detachSparkUI(rm.getValue()) } }) @@ -96,6 +96,10 @@ class HistoryServer( } } } + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } } initialize() @@ -145,14 +149,14 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList() = provider.getListing() + def getApplicationList(): Iterable[ApplicationHistoryInfo] = provider.getListing() /** * Returns the provider configuration to show in the listing page. * * @return A map with the provider's configuration. */ - def getProviderConfig() = provider.getConfig() + def getProviderConfig(): Map[String, String] = provider.getConfig() } @@ -190,11 +194,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run() = { - server.stop() - } - }) + Utils.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index b1270ade9f75..a2a97a7877ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,8 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) + extends Logging { private var propertiesFile: String = null parse(args.toList) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ede0a9dbefb8..f59d550d4f3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -28,7 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription import org.apache.spark.util.Utils -private[spark] class ApplicationInfo( +private[deploy] class ApplicationInfo( val startTime: Long, val id: String, val desc: ApplicationDescription, @@ -75,14 +75,17 @@ private[spark] class ApplicationInfo( } } - def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = { - val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + private[master] def addExecutor( + worker: WorkerInfo, + cores: Int, + useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerExecutorMB) executors(exec.id) = exec coresGranted += cores exec } - def removeExecutor(exec: ExecutorDesc) { + private[master] def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.id)) { removedExecutors += executors(exec.id) executors -= exec.id @@ -90,26 +93,30 @@ private[spark] class ApplicationInfo( } } - private val myMaxCores = desc.maxCores.getOrElse(defaultCores) + private val requestedCores = desc.maxCores.getOrElse(defaultCores) - def coresLeft: Int = myMaxCores - coresGranted + private[master] def coresLeft: Int = requestedCores - coresGranted private var _retryCount = 0 - def retryCount = _retryCount + private[master] def retryCount = _retryCount - def incrementRetryCount() = { + private[master] def incrementRetryCount() = { _retryCount += 1 _retryCount } - def resetRetryCount() = _retryCount = 0 + private[master] def resetRetryCount() = _retryCount = 0 - def markFinished(endState: ApplicationState.Value) { + private[master] def markFinished(endState: ApplicationState.Value) { state = endState endTime = System.currentTimeMillis() } + private[master] def isFinished: Boolean = { + state != ApplicationState.WAITING && state != ApplicationState.RUNNING + } + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index 38db02cd2421..017e8b55cbe7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -21,7 +21,7 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source -class ApplicationSource(val application: ApplicationInfo) extends Source { +private[master] class ApplicationSource(val application: ApplicationInfo) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index 67e6c5d66af0..37bfcdfdf477 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -17,11 +17,11 @@ package org.apache.spark.deploy.master -private[spark] object ApplicationState extends Enumeration { +private[master] object ApplicationState extends Enumeration { type ApplicationState = Value - val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value + val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 9d3d7938c6cc..b197dbcbfe29 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.DriverDescription import org.apache.spark.util.Utils -private[spark] class DriverInfo( +private[deploy] class DriverInfo( val startTime: Long, val id: String, val desc: DriverDescription, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala index 26a68bade3c6..35ff33a61653 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object DriverState extends Enumeration { +private[deploy] object DriverState extends Enumeration { type DriverState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala index 5d620dfcabad..fc62b094def6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} -private[spark] class ExecutorDesc( +private[master] class ExecutorDesc( val id: Int, val application: ApplicationInfo, val worker: WorkerInfo, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 36a2e2c6a634..f459ed5b3a1a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import akka.serialization.Serialization import org.apache.spark.Logging +import org.apache.spark.util.Utils /** @@ -33,7 +34,7 @@ import org.apache.spark.Logging * @param dir Directory to store files. Created if non-existent (but not recursively). * @param serialization Used to serialize our objects. */ -private[spark] class FileSystemPersistenceEngine( +private[master] class FileSystemPersistenceEngine( val dir: String, val serialization: Serialization) extends PersistenceEngine with Logging { @@ -48,7 +49,7 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + name).delete() } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) files.map(deserializeFromFile[T]) } @@ -59,9 +60,9 @@ private[spark] class FileSystemPersistenceEngine( val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { out.write(serialized) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d92d99310a58..ff2eed6dee70 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -43,84 +43,99 @@ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} -private[spark] class Master( +private[master] class Master( host: String, port: Int, webUiPort: Int, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + val conf: SparkConf) extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() - val conf = new SparkConf - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 - val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) - val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) - val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) - val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + + private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) + private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) + private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") val workers = new HashSet[WorkerInfo] - val idToWorker = new HashMap[String, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] - - val apps = new HashSet[ApplicationInfo] val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] val waitingApps = new ArrayBuffer[ApplicationInfo] - val completedApps = new ArrayBuffer[ApplicationInfo] - var nextAppNumber = 0 - val appIdToUI = new HashMap[String, SparkUI] + val apps = new HashSet[ApplicationInfo] + + private val idToWorker = new HashMap[String, WorkerInfo] + private val addressToWorker = new HashMap[Address, WorkerInfo] - val drivers = new HashSet[DriverInfo] - val completedDrivers = new ArrayBuffer[DriverInfo] - val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling - var nextDriverNumber = 0 + private val actorToApp = new HashMap[ActorRef, ApplicationInfo] + private val addressToApp = new HashMap[Address, ApplicationInfo] + private val completedApps = new ArrayBuffer[ApplicationInfo] + private var nextAppNumber = 0 + private val appIdToUI = new HashMap[String, SparkUI] + + private val drivers = new HashSet[DriverInfo] + private val completedDrivers = new ArrayBuffer[DriverInfo] + // Drivers currently spooled for scheduling + private val waitingDrivers = new ArrayBuffer[DriverInfo] + private var nextDriverNumber = 0 Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) - val masterSource = new MasterSource(this) + private val masterSource = new MasterSource(this) - val webUi = new MasterWebUI(this, webUiPort) + private val webUi = new MasterWebUI(this, webUiPort) - val masterPublicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") + private val masterPublicAddress = { + val envVar = conf.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host } - val masterUrl = "spark://" + host + ":" + port - var masterWebUiUrl: String = _ + private val masterUrl = "spark://" + host + ":" + port + private var masterWebUiUrl: String = _ - var state = RecoveryState.STANDBY + private var state = RecoveryState.STANDBY - var persistenceEngine: PersistenceEngine = _ + private var persistenceEngine: PersistenceEngine = _ - var leaderElectionAgent: LeaderElectionAgent = _ + private var leaderElectionAgent: LeaderElectionAgent = _ private var recoveryCompletionTask: Cancellable = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = conf.getBoolean("spark.deploy.spreadOut", true) + private val spreadOutApps = conf.getBoolean("spark.deploy.spreadOut", true) // Default maxCores for applications that don't specify it (i.e. pass Int.MaxValue) - val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue) + private val defaultCores = conf.getInt("spark.deploy.defaultCores", Int.MaxValue) if (defaultCores < 1) { throw new SparkException("spark.deploy.defaultCores must be positive") } + // Alternative application submission gateway that is stable across Spark versions + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) + private val restServer = + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + } else { + None + } + private val restServerBoundPort = restServer.map(_.start()) + override def preStart() { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") @@ -174,6 +189,7 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() + restServer.foreach(_.stop()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -188,7 +204,7 @@ private[spark] class Master( self ! RevokedLeadership } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -323,7 +339,11 @@ private[spark] class Master( if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") - appInfo.removeExecutor(exec) + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } exec.worker.removeExecutor(exec) val normalExit = exitStatus == Some(0) @@ -412,6 +432,10 @@ private[spark] class Master( if (canCompleteRecovery) { completeRecovery() } } + case UnregisterApplication(applicationId) => + logInfo(s"Received unregister request from application $applicationId") + idToApp.get(applicationId).foreach(finishApplication) + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") @@ -421,7 +445,9 @@ private[spark] class Master( } case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + sender ! MasterStateResponse( + host, port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state) } @@ -429,16 +455,16 @@ private[spark] class Master( timeOutDeadWorkers() } - case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + case BoundPortsRequest => { + sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) } } - def canCompleteRecovery = + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 - def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], + private def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], storedWorkers: Seq[WorkerInfo]) { for (app <- storedApps) { logInfo("Trying to recover app: " + app.id) @@ -469,7 +495,7 @@ private[spark] class Master( } } - def completeRecovery() { + private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. synchronized { if (state != RecoveryState.RECOVERING) { return } @@ -498,52 +524,28 @@ private[spark] class Master( } /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. This method will be called - * every time a new app joins or resource availability changes. + * Schedule executors to be launched on the workers. + * + * There are two modes of launching executors. The first attempts to spread out an application's + * executors on as many workers as possible, while the second does the opposite (i.e. launch them + * on as few workers as possible). The former is usually better for data locality purposes and is + * the default. + * + * The number of cores assigned to each executor is configurable. When this is explicitly set, + * multiple executors from the same application may be launched on the same worker if the worker + * has enough cores and memory. Otherwise, each executor grabs all the cores available on the + * worker by default, in which case only one executor may be launched on each worker. */ - private def schedule() { - if (state != RecoveryState.ALIVE) { return } - - // First schedule drivers, they take strict precedence over applications - // Randomization helps balance drivers - val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val numWorkersAlive = shuffledAliveWorkers.size - var curPos = 0 - - for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers - // We assign workers to each waiting driver in a round-robin fashion. For each driver, we - // start from the last worker that was assigned a driver, and continue onwards until we have - // explored all alive workers. - var launched = false - var numWorkersVisited = 0 - while (numWorkersVisited < numWorkersAlive && !launched) { - val worker = shuffledAliveWorkers(curPos) - numWorkersVisited += 1 - if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { - launchDriver(worker, driver) - waitingDrivers -= driver - launched = true - } - curPos = (curPos + 1) % numWorkersAlive - } - } - + private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores + // Try to spread out each app among all the workers, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -556,32 +558,61 @@ private[spark] class Master( pos = (pos + 1) % numUsable } // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) - app.state = ApplicationState.RUNNING - } + for (pos <- 0 until numUsable if assigned(pos) > 0) { + allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } } else { - // Pack each app into as few nodes as possible until we've assigned all its cores + // Pack each app into as few workers as possible until we've assigned all its cores for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) - app.state = ApplicationState.RUNNING - } - } + allocateWorkerResourceToExecutors(app, app.coresLeft, worker) } } } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { + /** + * Allocate a worker's resources to one or more executors. + * @param app the info of the application which the executors belong to + * @param coresToAllocate cores on this worker to be allocated to this application + * @param worker the worker info + */ + private def allocateWorkerResourceToExecutors( + app: ApplicationInfo, + coresToAllocate: Int, + worker: WorkerInfo): Unit = { + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) + var coresLeft = coresToAllocate + while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { + val exec = app.addExecutor(worker, coresPerExecutor) + coresLeft -= coresPerExecutor + launchExecutor(worker, exec) + app.state = ApplicationState.RUNNING + } + } + + /** + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. + */ + private def schedule(): Unit = { + if (state != RecoveryState.ALIVE) { return } + // Drivers take strict precedence over executors + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver + } + } + } + startExecutorsOnWorkers() + } + + private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(masterUrl, @@ -590,7 +621,7 @@ private[spark] class Master( exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } - def registerWorker(worker: WorkerInfo): Boolean = { + private def registerWorker(worker: WorkerInfo): Boolean = { // There may be one or more refs to dead workers on this same node (w/ different ID's), // remove them. workers.filter { w => @@ -618,7 +649,7 @@ private[spark] class Master( true } - def removeWorker(worker: WorkerInfo) { + private def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id @@ -641,22 +672,22 @@ private[spark] class Master( persistenceEngine.removeWorker(worker) } - def relaunchDriver(driver: DriverInfo) { + private def relaunchDriver(driver: DriverInfo) { driver.worker = None driver.state = DriverState.RELAUNCHING waitingDrivers += driver schedule() } - def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } - def registerApplication(app: ApplicationInfo): Unit = { + private def registerApplication(app: ApplicationInfo): Unit = { val appAddress = app.driver.path.address - if (addressToWorker.contains(appAddress)) { + if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return } @@ -669,7 +700,7 @@ private[spark] class Master( waitingApps += app } - def finishApplication(app: ApplicationInfo) { + private def finishApplication(app: ApplicationInfo) { removeApplication(app, ApplicationState.FINISHED) } @@ -717,36 +748,41 @@ private[spark] class Master( * Rebuild a new SparkUI from the given application's event logs. * Return whether this is successful. */ - def rebuildSparkUI(app: ApplicationInfo): Boolean = { + private def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" try { - val eventLogFile = app.desc.eventLogDir - .map { dir => EventLoggingListener.getLogPath(dir, app.id) } + val eventLogDir = app.desc.eventLogDir .getOrElse { // Event logging is not enabled for this application app.desc.appUiUrl = notFoundBasePath return false } - - val fs = Utils.getHadoopFileSystem(eventLogFile, hadoopConf) - - if (fs.exists(new Path(eventLogFile + EventLoggingListener.IN_PROGRESS))) { + + val eventLogFilePrefix = EventLoggingListener.getLogPath( + eventLogDir, app.id, app.desc.eventLogCodec) + val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) + val inProgressExists = fs.exists(new Path(eventLogFilePrefix + + EventLoggingListener.IN_PROGRESS)) + + if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress - val title = s"Application history not found (${app.id})" - var msg = s"Application $appName is still in progress." - logWarning(msg) - msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" - return false + logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") } - - val (logInput, sparkVersion) = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) + + val (eventLogFile, status) = if (inProgressExists) { + (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + } else { + (eventLogFilePrefix, " (completed)") + } + + val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { - replayBus.replay(logInput, sparkVersion) + replayBus.replay(logInput, eventLogFile, maybeTruncated) } finally { logInput.close() } @@ -759,7 +795,7 @@ private[spark] class Master( case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" - var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir}." + var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir.get}." logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") @@ -778,14 +814,14 @@ private[spark] class Master( } /** Generate a new app ID given a app's submission date */ - def newApplicationId(submitDate: Date): String = { + private def newApplicationId(submitDate: Date): String = { val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 appId } /** Check for, and remove, any timed-out workers */ - def timeOutDeadWorkers() { + private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray @@ -802,19 +838,19 @@ private[spark] class Master( } } - def newDriverId(submitDate: Date): String = { + private def newDriverId(submitDate: Date): String = { val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber) nextDriverNumber += 1 appId } - def createDriver(desc: DriverDescription): DriverInfo = { + private def createDriver(desc: DriverDescription): DriverInfo = { val now = System.currentTimeMillis() val date = new Date(now) new DriverInfo(now, newDriverId(date), desc, date) } - def launchDriver(worker: WorkerInfo, driver: DriverInfo) { + private def launchDriver(worker: WorkerInfo, driver: DriverInfo) { logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) @@ -822,7 +858,10 @@ private[spark] class Master( driver.state = DriverState.RUNNING } - def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) { + private def removeDriver( + driverId: String, + finalState: DriverState, + exception: Option[Exception]) { drivers.find(d => d.id == driverId) match { case Some(driver) => logInfo(s"Removing driver: $driverId") @@ -843,7 +882,7 @@ private[spark] class Master( } } -private[spark] object Master extends Logging { +private[deploy] object Master extends Logging { val systemName = "sparkMaster" private val actorName = "Master" @@ -851,7 +890,7 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -860,9 +899,9 @@ private[spark] object Master extends Logging { * * @throws SparkException if the url is invalid */ - def toAkkaUrl(sparkUrl: String): String = { + def toAkkaUrl(sparkUrl: String, protocol: String): String = { val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + AkkaUtils.address(protocol, systemName, host, port, actorName) } /** @@ -870,24 +909,31 @@ private[spark] object Master extends Logging { * * @throws SparkException if the url is invalid */ - def toAkkaAddress(sparkUrl: String): Address = { + def toAkkaAddress(sparkUrl: String, protocol: String): Address = { val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address("akka.tcp", systemName, host, port) + Address(protocol, systemName, host, port) } + /** + * Start the Master and return a four tuple of: + * (1) The Master actor system + * (2) The bound port + * (3) The web UI bound port + * (4) The REST server bound port, if any + */ def startSystemAndActor( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) - val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val actor = actorSystem.actorOf( + Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) + val timeout = RpcUtils.askTimeout(conf) + val portsRequest = actor.ask(BoundPortsRequest)(timeout) + val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] + (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index e34bee785429..435b9b12f83b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.{IntParam, Utils} /** * Command-line parser for the master. */ -private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { +private[master] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 @@ -49,7 +49,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = conf.get("spark.master.ui.port").toInt } - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) host = value @@ -84,7 +84,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { /** * Print usage and exit JVM with the given exit code. */ - def printUsageAndExit(exitCode: Int) { + private def printUsageAndExit(exitCode: Int) { System.err.println( "Usage: Master [options]\n" + "\n" + diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index db72d8ae9bda..15c6296888f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -36,7 +36,7 @@ private[master] object MasterMessages { case object CompleteRecovery - case object RequestWebUIPort + case object BoundPortsRequest - case class WebUIPortResponse(webUIBoundPort: Int) + case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 2e0e1e7036ac..da5060778ede 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -87,7 +87,7 @@ trait PersistenceEngine { def close() {} } -private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { +private[master] class BlackHolePersistenceEngine extends PersistenceEngine { override def persist(name: String, obj: Object): Unit = {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 1096eb036835..351db8fab204 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -49,22 +49,29 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") - def createPersistenceEngine() = { + def createPersistenceEngine(): PersistenceEngine = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } - def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } } -private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) - def createLeaderElectionAgent(master: LeaderElectable) = + def createPersistenceEngine(): PersistenceEngine = { + new ZooKeeperPersistenceEngine(conf, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { new ZooKeeperLeaderElectionAgent(master, conf) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala index 256a5a7c28e4..aa0f02fa625c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object RecoveryState extends Enumeration { +private[deploy] object RecoveryState extends Enumeration { type MasterState = Value val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala index 4781a80d470e..5b22481ea8c5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala @@ -25,12 +25,12 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -object SparkCuratorUtil extends Logging { +private[deploy] object SparkCuratorUtil extends Logging { - val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 - val ZK_SESSION_TIMEOUT_MILLIS = 60000 - val RETRY_WAIT_MILLIS = 5000 - val MAX_RECONNECT_ATTEMPTS = 3 + private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 + private val ZK_SESSION_TIMEOUT_MILLIS = 60000 + private val RETRY_WAIT_MILLIS = 5000 + private val MAX_RECONNECT_ATTEMPTS = 3 def newClient(conf: SparkConf): CuratorFramework = { val ZK_URL = conf.get("spark.deploy.zookeeper.url") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e449..9b3d48c6edc8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -104,7 +104,7 @@ private[spark] class WorkerInfo( "http://" + this.publicAddress + ":" + this.webUiPort } - def setState(state: WorkerState.Value) = { + def setState(state: WorkerState.Value): Unit = { this.state = state } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala index 0b36ef60051f..b60baaadfb4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object WorkerState extends Enumeration { +private[master] object WorkerState extends Enumeration { type WorkerState = Value val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 8eaa0ad94851..4823fd7cac0c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, +private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -35,7 +35,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectab start() - def start() { + private def start() { logInfo("Starting ZooKeeper LeaderElection agent") zk = SparkCuratorUtil.newClient(conf) leaderLatch = new LeaderLatch(zk, WORKING_DIR) @@ -72,7 +72,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectab } } - def updateLeadershipStatus(isLeader: Boolean) { + private def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER masterActor.electedLeader() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index e11ac031fb9c..a285783f7200 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -28,12 +28,12 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} -private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine - with Logging -{ - val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" - val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) + with Logging { + + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) SparkCuratorUtil.mkdir(zk, WORKING_DIR) @@ -46,7 +46,7 @@ private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializati zk.delete().forPath(WORKING_DIR + "/" + name) } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) file.map(deserializeFromFile[T]).flatten } @@ -61,7 +61,7 @@ private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializati zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { + private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 3aae2b95d739..273f077bd8f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -24,6 +24,7 @@ import scala.xml.Node import akka.pattern.ask import org.json4s.JValue +import org.json4s.JsonAST.JNothing import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} @@ -31,7 +32,7 @@ import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils -private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { +private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { private val master = parent.masterActorRef private val timeout = parent.timeout @@ -44,7 +45,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) - JsonProtocol.writeApplicationInfo(app) + if (app == null) { + JNothing + } else { + JsonProtocol.writeApplicationInfo(app) + } } /** Executor details for a particular application */ @@ -55,6 +60,10 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) + if (app == null) { + val msg =

No running application with ID {appId}
+ return UIUtils.basicSparkPage(msg, "Not Found") + } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq @@ -85,7 +94,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
  • Executor Memory: - {Utils.megabytesToString(app.desc.memoryPerSlave)} + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
  • Submit Date: {app.submitDate}
  • State: {app.state}
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala index d8daff3e7fb9..e021f1eef794 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} -private[spark] class HistoryNotFoundPage(parent: MasterWebUI) +private[ui] class HistoryNotFoundPage(parent: MasterWebUI) extends WebUIPage("history/not-found") { /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a2872..1f2c3fdbfb2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -26,12 +26,12 @@ import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { +private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private val master = parent.masterActorRef private val timeout = parent.timeout @@ -41,24 +41,49 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { JsonProtocol.writeMasterState(state) } + def handleAppKillRequest(request: HttpServletRequest): Unit = { + handleKillRequest(request, id => { + parent.master.idToApp.get(id).foreach { app => + parent.master.removeApplication(app, ApplicationState.KILLED) + } + }) + } + + def handleDriverKillRequest(request: HttpServletRequest): Unit = { + handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + } + + private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { + if (parent.killEnabled && + parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { + val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean + val id = Option(request.getParameter("id")) + if (id.isDefined && killFlag) { + action(id.get) + } + + Thread.sleep(100) + } + } + /** Index view listing applications and executors */ def render(request: HttpServletRequest): Seq[Node] = { val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] val state = Await.result(stateFuture, timeout) - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") + val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) - val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", - "Main Class") + val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", + "Memory", "Main Class") val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse @@ -66,13 +91,21 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + def hasDrivers: Boolean = activeDrivers.length > 0 || completedDrivers.length > 0 val content =
    • URL: {state.uri}
    • + { + state.restUri.map { uri => +
    • + REST URL: {uri} + (cluster mode) +
    • + }.getOrElse { Seq.empty } + }
    • Workers: {state.workers.size}
    • Cores: {state.workers.map(_.cores).sum} Total, {state.workers.map(_.coresUsed).sum} Used
    • @@ -155,9 +188,21 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } private def appRow(app: ApplicationInfo): Seq[Node] = { + val killLink = if (parent.killEnabled && + (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { + val confirm = + s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
      + + + (kill) +
      + } {app.id} + {killLink} {app.desc.name} @@ -165,8 +210,8 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.coresGranted} - - {Utils.megabytesToString(app.desc.memoryPerSlave)} + + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} {UIUtils.formatDate(app.submitDate)} {app.desc.user} @@ -176,8 +221,21 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } private def driverRow(driver: DriverInfo): Seq[Node] = { + val killLink = if (parent.killEnabled && + (driver.state == DriverState.RUNNING || + driver.state == DriverState.SUBMITTED || + driver.state == DriverState.RELAUNCHING)) { + val confirm = + s"if (window.confirm('Are you sure you want to kill driver ${driver.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
      + + + (kill) +
      + } - {driver.id} + {driver.id} {killLink} {driver.submitDate} {driver.worker.map(w => {w.id.toString}).getOrElse("None")} @@ -188,7 +246,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(driver.desc.mem.toLong)} - {driver.desc.command.arguments(1)} + {driver.desc.command.arguments(2)} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 73400c5affb5..aad9c87bdb98 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -21,26 +21,32 @@ import org.apache.spark.Logging import org.apache.spark.deploy.master.Master import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. */ -private[spark] +private[master] class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { val masterActorRef = master.self - val timeout = AkkaUtils.askTimeout(master.conf) + val timeout = RpcUtils.askTimeout(master.conf) + val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) initialize() /** Initialize all components of the server. */ def initialize() { + val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(new HistoryNotFoundPage(this)) - attachPage(new MasterPage(this)) + attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + attachHandler(createRedirectHandler( + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + attachHandler(createRedirectHandler( + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ @@ -56,6 +62,6 @@ class MasterWebUI(val master: Master, requestedPort: Int) } } -private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +private[master] object MasterWebUI { + private val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala new file mode 100644 index 000000000000..b8fd406fb6f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.{DataOutputStream, FileNotFoundException} +import java.net.{HttpURLConnection, SocketException, URL} +import javax.servlet.http.HttpServletResponse + +import scala.io.Source + +import com.fasterxml.jackson.core.JsonProcessingException +import com.google.common.base.Charsets + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A client that submits applications to the standalone Master using a REST protocol. + * This client is intended to communicate with the [[StandaloneRestServer]] and is + * currently used for cluster mode only. + * + * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], + * where [action] can be one of create, kill, or status. Each type of request is represented in + * an HTTP message sent to the following prefixes: + * (1) submit - POST to /submissions/create + * (2) kill - POST /submissions/kill/[submissionId] + * (3) status - GET /submissions/status/[submissionId] + * + * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields. + * Otherwise, the URL fully specifies the intended action of the client. + * + * Since the protocol is expected to be stable across Spark versions, existing fields cannot be + * added or removed, though new optional fields can be added. In the rare event that forward or + * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2). + * + * The client and the server must communicate using the same version of the protocol. If there + * is a mismatch, the server will respond with the highest protocol version it supports. A future + * implementation of this client can use that information to retry using the version specified + * by the server. + */ +private[deploy] class StandaloneRestClient extends Logging { + import StandaloneRestClient._ + + /** + * Submit an application specified by the parameters in the provided request. + * + * If the submission was successful, poll the status of the submission and report + * it to the user. Otherwise, report the error message provided by the server. + */ + private[rest] def createSubmission( + master: String, + request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch an application in $master.") + validateMaster(master) + val url = getSubmitUrl(master) + val response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + reportSubmissionStatus(master, s) + handleRestResponse(s) + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request that the server kill the specified submission. */ + def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill submission $submissionId in $master.") + validateMaster(master) + val response = post(getKillUrl(master, submissionId)) + response match { + case k: KillSubmissionResponse => handleRestResponse(k) + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Request the status of a submission from the server. */ + def requestSubmissionStatus( + master: String, + submissionId: String, + quiet: Boolean = false): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of submission $submissionId in $master.") + validateMaster(master) + val response = get(getStatusUrl(master, submissionId)) + response match { + case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) } + case unexpected => handleUnexpectedRestResponse(unexpected) + } + response + } + + /** Construct a message that captures the specified parameters for submitting an application. */ + private[rest] def constructSubmitRequest( + appResource: String, + mainClass: String, + appArgs: Array[String], + sparkProperties: Map[String, String], + environmentVariables: Map[String, String]): CreateSubmissionRequest = { + val message = new CreateSubmissionRequest + message.clientSparkVersion = sparkVersion + message.appResource = appResource + message.mainClass = mainClass + message.appArgs = appArgs + message.sparkProperties = sparkProperties + message.environmentVariables = environmentVariables + message.validate() + message + } + + /** Send a GET request to the specified URL. */ + private def get(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending GET request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("GET") + readResponse(conn) + } + + /** Send a POST request to the specified URL. */ + private def post(url: URL): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url.") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + readResponse(conn) + } + + /** Send a POST request with the given JSON as the body to the specified URL. */ + private def postJson(url: URL, json: String): SubmitRestProtocolResponse = { + logDebug(s"Sending POST request to server at $url:\n$json") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } + readResponse(conn) + } + + /** + * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]]. + * If the response represents an error, report the embedded message to the user. + * Exposed for testing. + */ + private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + try { + val dataStream = + if (connection.getResponseCode == HttpServletResponse.SC_OK) { + connection.getInputStream + } else { + connection.getErrorStream + } + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") + } + } catch { + case unreachable @ (_: FileNotFoundException | _: SocketException) => + throw new SubmitRestConnectionException( + s"Unable to connect to server ${connection.getURL}", unreachable) + case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + throw new SubmitRestProtocolException( + "Malformed response received from server", malformed) + } + } + + /** Return the REST URL for creating a new submission. */ + private def getSubmitUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/create") + } + + /** Return the REST URL for killing an existing submission. */ + private def getKillUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/kill/$submissionId") + } + + /** Return the REST URL for requesting the status of an existing submission. */ + private def getStatusUrl(master: String, submissionId: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/status/$submissionId") + } + + /** Return the base URL for communicating with the server, including the protocol version. */ + private def getBaseUrl(master: String): String = { + val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + s"http://$masterUrl/$PROTOCOL_VERSION/submissions" + } + + /** Throw an exception if this is not standalone mode. */ + private def validateMaster(master: String): Unit = { + if (!master.startsWith("spark://")) { + throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + } + } + + /** Report the status of a newly created submission. */ + private def reportSubmissionStatus( + master: String, + submitResponse: CreateSubmissionResponse): Unit = { + if (submitResponse.success) { + val submissionId = submitResponse.submissionId + if (submissionId != null) { + logInfo(s"Submission successfully created as $submissionId. Polling submission state...") + pollSubmissionStatus(master, submissionId) + } else { + // should never happen + logError("Application successfully submitted, but submission ID was not provided!") + } + } else { + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError(s"Application submission failed$failMessage") + } + } + + /** + * Poll the status of the specified submission and log it. + * This retries up to a fixed number of times before giving up. + */ + private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val response = requestSubmissionStatus(master, submissionId, quiet = true) + val statusResponse = response match { + case s: SubmissionStatusResponse => s + case _ => return // unexpected type, let upstream caller handle it + } + if (statusResponse.success) { + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $submissionId is now $state.") + case _ => logError(s"State of driver $submissionId was not found!") + } + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } + logError(s"Error: Master did not recognize driver $submissionId.") + } + + /** Log the response sent by the server in the REST application submission protocol. */ + private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { + logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}") + } + + /** Log an appropriate error if the response sent by the server is not of the expected type. */ + private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { + logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") + } +} + +private[rest] object StandaloneRestClient { + private val REPORT_DRIVER_STATUS_INTERVAL = 1000 + private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 + val PROTOCOL_VERSION = "v1" + + /** + * Submit an application, assuming Spark parameters are specified through the given config. + * This is abstracted to its own method for testing purposes. + */ + def run( + appResource: String, + mainClass: String, + appArgs: Array[String], + conf: SparkConf, + env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + val master = conf.getOption("spark.master").getOrElse { + throw new IllegalArgumentException("'spark.master' must be set.") + } + val sparkProperties = conf.getAll.toMap + val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } + val client = new StandaloneRestClient + val submitRequest = client.constructSubmitRequest( + appResource, mainClass, appArgs, sparkProperties, environmentVariables) + client.createSubmission(master, submitRequest) + } + + def main(args: Array[String]): Unit = { + if (args.size < 2) { + sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.exit(1) + } + val appResource = args(0) + val mainClass = args(1) + val appArgs = args.slice(2, args.size) + val conf = new SparkConf + run(appResource, mainClass, appArgs, conf) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala new file mode 100644 index 000000000000..2d6b8d420479 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.File +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source + +import akka.actor.ActorRef +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ + +/** + * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * This is intended to be embedded in the standalone Master and used in cluster mode only. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + * + * @param host the address this server should bind to + * @param requestedPort the port this server will attempt to bind to + * @param masterActor reference to the Master actor to which requests can be sent + * @param masterUrl the URL of the Master new drivers will attempt to connect to + * @param masterConf the conf used by the Master + */ +private[deploy] class StandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends Logging { + + import StandaloneRestServer._ + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/$PROTOCOL_VERSION/submissions" + protected val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), + s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), + s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object StandaloneRestServer { + val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. + */ +private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse = { + val askTimeout = RpcUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val k = new KillSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.submissionId = submissionId + k.success = response.success + k + } +} + +/** + * A servlet for handling status requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StandaloneRestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse = { + val askTimeout = RpcUtils.askTimeout(conf) + val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } + val d = new SubmissionStatusResponse + d.serverSparkVersion = sparkVersion + d.submissionId = submissionId + d.success = response.found + d.driverState = response.state.map(_.toString).orNull + d.workerId = response.workerId.orNull + d.workerHostPort = response.workerHostPort.orNull + d.message = message.orNull + d + } +} + +/** + * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class SubmitRequestServlet( + masterActor: ActorRef, + masterUrl: String, + conf: SparkConf) + extends StandaloneRestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. + */ + private def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that takes into account memory, java options, + * classpath and other settings to launch the driver. This does not currently consider + * fields used by python applications since python is not supported in standalone + * cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + + // Construct driver description + val conf = new SparkConf(false) + .setAll(sparkProperties) + .set("spark.master", masterUrl) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", "{{USER_JAR}}", mainClass) ++ appArgs, // args to the DriverWrapper + environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) + } +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends StandaloneRestServlet { + private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." + val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala new file mode 100644 index 000000000000..b97921ec934a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An exception thrown in the REST application submission protocol. + */ +private[rest] class SubmitRestProtocolException(message: String, cause: Throwable = null) + extends Exception(message, cause) + +/** + * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]]. + */ +private[rest] class SubmitRestMissingFieldException(message: String) + extends SubmitRestProtocolException(message) + +/** + * An exception thrown if the REST client cannot reach the REST server. + */ +private[deploy] class SubmitRestConnectionException(message: String, cause: Throwable) + extends SubmitRestProtocolException(message, cause) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala new file mode 100644 index 000000000000..e6615a3174ce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.Utils + +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. + * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +private[rest] abstract class SubmitRestProtocolMessage { + @JsonIgnore + val messageType = Utils.getFormattedClassName(this) + + val action: String = messageType + var message: String = null + + // For JSON deserialization + private def setAction(a: String): Unit = { } + + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ + def toJson: String = { + validate() + SubmitRestProtocolMessage.mapper.writeValueAsString(this) + } + + /** + * Assert the validity of the message. + * If the validation fails, throw a [[SubmitRestProtocolException]]. + */ + final def validate(): Unit = { + try { + doValidate() + } catch { + case e: Exception => + throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e) + } + } + + /** Assert the validity of the message */ + protected def doValidate(): Unit = { + if (action == null) { + throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") + } + } + + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet[T](value: T, name: String): Unit = { + if (value == null) { + throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.") + } + } + + /** + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. + */ + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestProtocolException(failMessage) } + } +} + +/** + * Helper methods to process serialized [[SubmitRestProtocolMessage]]s. + */ +private[spark] object SubmitRestProtocolMessage { + private val packagePrefix = this.getClass.getPackage.getName + private val mapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .enable(SerializationFeature.INDENT_OUTPUT) + .registerModule(DefaultScalaModule) + + /** + * Parse the value of the action field from the given JSON. + * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. + */ + def parseAction(json: String): String = { + val value: Option[String] = parse(json) match { + case JObject(fields) => + fields.collectFirst { case ("action", v) => v }.collect { case JString(s) => s } + case _ => None + } + value.getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method first parses the action from the JSON and uses it to infer the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ + def fromJson(json: String): SubmitRestProtocolMessage = { + val className = parseAction(json) + val clazz = Class.forName(packagePrefix + "." + className) + .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) + fromJson(json, clazz) + } + + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. This is useful for deserializing JSON that + * represents custom user-defined messages. + */ + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + mapper.readValue(json, clazz) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala new file mode 100644 index 000000000000..d80abdf15fb3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.util.Try + +import org.apache.spark.util.Utils + +/** + * An abstract request sent from the client in the REST application submission protocol. + */ +private[rest] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + var clientSparkVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") + } +} + +/** + * A request to launch a new application in the REST application submission protocol. + */ +private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { + var appResource: String = null + var mainClass: String = null + var appArgs: Array[String] = null + var sparkProperties: Map[String, String] = null + var environmentVariables: Map[String, String] = null + + protected override def doValidate(): Unit = { + super.doValidate() + assert(sparkProperties != null, "No Spark properties set!") + assertFieldIsSet(appResource, "appResource") + assertPropertyIsSet("spark.app.name") + assertPropertyIsBoolean("spark.driver.supervise") + assertPropertyIsNumeric("spark.driver.cores") + assertPropertyIsNumeric("spark.cores.max") + assertPropertyIsMemory("spark.driver.memory") + assertPropertyIsMemory("spark.executor.memory") + } + + private def assertPropertyIsSet(key: String): Unit = + assertFieldIsSet(sparkProperties.getOrElse(key, null), key) + + private def assertPropertyIsBoolean(key: String): Unit = + assertProperty[Boolean](key, "boolean", _.toBoolean) + + private def assertPropertyIsNumeric(key: String): Unit = + assertProperty[Int](key, "numeric", _.toInt) + + private def assertPropertyIsMemory(key: String): Unit = + assertProperty[Int](key, "memory", Utils.memoryStringToMb) + + /** Assert that a Spark property can be converted to a certain type. */ + private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = { + sparkProperties.get(key).foreach { value => + Try(convert(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Property '$key' expected $valueType value: actual was '$value'.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala new file mode 100644 index 000000000000..8fde8c142a4c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean + +/** + * An abstract response sent from the server in the REST application submission protocol. + */ +private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + var serverSparkVersion: String = null + var success: Boolean = null + var unknownFields: Array[String] = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") + } +} + +/** + * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. + */ +private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a kill request in the REST application submission protocol. + */ +private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * A response to a status request in the REST application submission protocol. + */ +private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { + var submissionId: String = null + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(submissionId, "submissionId") + assertFieldIsSet(success, "success") + } +} + +/** + * An error response message used in the REST application submission protocol. + */ +private[rest] class ErrorResponse extends SubmitRestProtocolResponse { + // The highest protocol version that the server knows about + // This is set when the client specifies an unknown version + var highestProtocolVersion: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 28e9662db5da..0a1d60f58bc5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -20,16 +20,18 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} import java.lang.System._ +import scala.collection.JavaConversions._ import scala.collection.Map import org.apache.spark.Logging import org.apache.spark.deploy.Command +import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils /** ** Utilities for running commands with the spark classpath. */ -private[spark] +private[deploy] object CommandUtils extends Logging { /** @@ -54,12 +56,10 @@ object CommandUtils extends Logging { } private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts(command, memory, sparkHome) ++ Seq(command.mainClass) ++ - command.arguments + val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() + cmd.toSeq ++ Seq(command.mainClass) ++ command.arguments } /** @@ -92,34 +92,6 @@ object CommandUtils extends Logging { command.javaOpts) } - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. - */ - private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") - - // Exists for backwards compatibility with older Spark versions - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString) - .getOrElse(Nil) - if (workerLocalOpts.length > 0) { - logWarning("SPARK_JAVA_OPTS was set on the worker. It is deprecated in Spark 1.0.") - logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.") - } - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment = command.environment) - val userClassPath = command.classPathEntries ++ Seq(classPath) - - val javaVersion = System.getProperty("java.version") - val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None - Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts - } - /** Spawn a thread that will redirect a given stream to a file */ def redirectStream(in: InputStream, file: File) { val out = new FileOutputStream(file, true) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e..ef7a703bffe6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -20,64 +20,73 @@ package org.apache.spark.deploy.worker import java.io._ import scala.collection.JavaConversions._ -import scala.collection.Map import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil} +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.util.{Utils, Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. * This is currently only used in standalone cluster deploy mode. */ -private[spark] class DriverRunner( - val conf: SparkConf, +private[deploy] class DriverRunner( + conf: SparkConf, val driverId: String, val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, val worker: ActorRef, - val workerUrl: String) + val workerUrl: String, + val securityManager: SecurityManager) extends Logging { - @volatile var process: Option[Process] = None - @volatile var killed = false + @volatile private var process: Option[Process] = None + @volatile private var killed = false // Populated once finished - var finalState: Option[DriverState] = None - var finalException: Option[Exception] = None - var finalExitCode: Option[Int] = None + private[worker] var finalState: Option[DriverState] = None + private[worker] var finalException: Option[Exception] = None + private var finalExitCode: Option[Int] = None // Decoupled for testing - private[deploy] def setClock(_clock: Clock) = clock = _clock - private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper - private var clock = new Clock { - def currentTimeMillis(): Long = System.currentTimeMillis() + def setClock(_clock: Clock): Unit = { + clock = _clock } + + def setSleeper(_sleeper: Sleeper): Unit = { + sleeper = _sleeper + } + + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) } /** Starts a thread to run and manage the driver. */ - def start() = { + private[worker] def start() = { new Thread("DriverRunner for " + driverId) { override def run() { try { val driverDir = createWorkingDirectory() val localJarFilename = downloadUserJar(driverDir) - // Make sure user application jar is on the classpath + def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case "{{USER_JAR}}" => localJarFilename + case other => other + } + // TODO: If we add ability to submit multiple jars they should also be added here val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -104,19 +113,13 @@ private[spark] class DriverRunner( } /** Terminate this driver (or prevent it from ever starting if not yet started) */ - def kill() { + private[worker] def kill() { synchronized { process.foreach(p => p.destroy()) killed = true } } - /** Replace variables in a command argument passed to us */ - private def substituteVariables(argument: String): String = argument match { - case "{{WORKER_URL}}" => workerUrl - case other => other - } - /** * Creates the working directory for this driver. * Will throw an exception if there are errors preparing the directory. @@ -134,12 +137,9 @@ private[spark] class DriverRunner( * Will throw an exception if there are errors downloading the jar. */ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val jarFileSystem = jarPath.getFileSystem(hadoopConf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) val jarFileName = jarPath.getName val localJarFile = new File(driverDir, jarFileName) @@ -147,7 +147,14 @@ private[spark] class DriverRunner( if (!localJarFile.exists()) { // May already exist if running multiple workers on one node logInfo(s"Copying user jar $jarPath to $destPath") - FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf) + Utils.fetchFile( + driverDesc.jarUrl, + driverDir, + conf, + securityManager, + hadoopConf, + System.currentTimeMillis(), + useCache = false) } if (!localJarFile.exists()) { // Verify copy succeeded @@ -159,7 +166,7 @@ private[spark] class DriverRunner( private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { builder.directory(baseDir) - def initialize(process: Process) = { + def initialize(process: Process): Unit = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") CommandUtils.redirectStream(process.getInputStream, stdout) @@ -173,8 +180,8 @@ private[spark] class DriverRunner( runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) } - private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, - supervise: Boolean) { + def runCommandWithRetry( + command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = { // Time to wait between submission retries. var waitSeconds = 1 // A run of this many seconds resets the exponential back-off. @@ -191,9 +198,9 @@ private[spark] class DriverRunner( initialize(process.get) } - val processStart = clock.currentTimeMillis() + val processStart = clock.getTimeMillis() val exitCode = process.get.waitFor() - if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { waitSeconds = 1 } @@ -209,10 +216,6 @@ private[spark] class DriverRunner( } } -private[deploy] trait Clock { - def currentTimeMillis(): Long -} - private[deploy] trait Sleeper { def sleep(seconds: Int) } @@ -224,8 +227,8 @@ private[deploy] trait ProcessBuilderLike { } private[deploy] object ProcessBuilderLike { - def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { - def start() = processBuilder.start() - def command = processBuilder.command() + def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { + override def start(): Process = processBuilder.start() + override def command: Seq[String] = processBuilder.command() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df70..d1a12b01e78f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,32 +17,50 @@ package org.apache.spark.deploy.worker -import akka.actor._ +import java.io.File import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. + * This is used in standalone cluster mode only. */ object DriverWrapper { def main(args: Array[String]) { args.toList match { - case workerUrl :: mainClass :: extraArgs => + /* + * IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both + * backward and forward compatible across future Spark versions. Because this gateway + * uses this class to launch the driver, the ordering and semantics of the arguments + * here must also remain consistent across versions. + */ + case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) + + val currentLoader = Thread.currentThread.getContextClassLoader + val userJarUrl = new File(userJar).toURI().toURL() + val loader = + if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader) + } else { + new MutableURLClassLoader(Array(userJarUrl), currentLoader) + } + Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(args(1)) + val clazz = Class.forName(mainClass, true, loader) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.shutdown() case _ => - System.err.println("Usage: DriverWrapper [options]") + System.err.println("Usage: DriverWrapper [options]") System.exit(-1) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index acbdf0d8bd7b..7aa85b732fc8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -26,15 +26,16 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} -import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.Utils import org.apache.spark.util.logging.FileAppender /** * Manages the execution of one executor process. * This is currently only used in standalone mode. */ -private[spark] class ExecutorRunner( +private[deploy] class ExecutorRunner( val appId: String, val execId: Int, val appDesc: ApplicationDescription, @@ -43,36 +44,33 @@ private[spark] class ExecutorRunner( val worker: ActorRef, val workerId: String, val host: String, + val webUiPort: Int, + val publicAddress: String, val sparkHome: File, val executorDir: File, val workerUrl: String, - val conf: SparkConf, + conf: SparkConf, val appLocalDirs: Seq[String], - var state: ExecutorState.Value) + @volatile var state: ExecutorState.Value) extends Logging { - val fullId = appId + "/" + execId - var workerThread: Thread = null - var process: Process = null - var stdoutAppender: FileAppender = null - var stderrAppender: FileAppender = null + private val fullId = appId + "/" + execId + private var workerThread: Thread = null + private var process: Process = null + private var stdoutAppender: FileAppender = null + private var stderrAppender: FileAppender = null // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. - var shutdownHook: Thread = null + private var shutdownHook: AnyRef = null - def start() { + private[worker] def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { override def run() { fetchAndRunExecutor() } } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - killProcess(Some("Worker shutting down")) - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } } /** @@ -84,32 +82,35 @@ private[spark] class ExecutorRunner( var exitCode: Option[Int] = None if (process != null) { logInfo("Killing process!") - process.destroy() - process.waitFor() if (stdoutAppender != null) { stdoutAppender.stop() } if (stderrAppender != null) { stderrAppender.stop() } + process.destroy() exitCode = Some(process.waitFor()) } worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) } /** Stop this executor runner, including killing the process it launched */ - def kill() { + private[worker] def kill() { if (workerThread != null) { // the workerThread will kill the child process when interrupted workerThread.interrupt() workerThread = null state = ExecutorState.KILLED - Runtime.getRuntime.removeShutdownHook(shutdownHook) + try { + Utils.removeShutdownHook(shutdownHook) + } catch { + case e: IllegalStateException => None + } } } /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ - def substituteVariables(argument: String): String = argument match { + private[worker] def substituteVariables(argument: String): String = argument match { case "{{WORKER_URL}}" => workerUrl case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host @@ -121,7 +122,7 @@ private[spark] class ExecutorRunner( /** * Download and run the executor described in our ApplicationDescription */ - def fetchAndRunExecutor() { + private def fetchAndRunExecutor() { try { // Launch the process val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, @@ -130,10 +131,17 @@ private[spark] class ExecutorRunner( logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) builder.directory(executorDir) - builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(",")) + builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") + + // Add webUI log urls + val baseUrl = + s"http://$publicAddress:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType=" + builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr") + builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout") + process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( command.mkString("\"", "\" \"", "\""), "=" * 40) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 13599830123d..3ee2eb69e8a4 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -31,8 +31,8 @@ import scala.util.Random import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI @@ -42,7 +42,7 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} /** * @param masterAkkaUrls Each url should be a valid akka url. */ -private[spark] class Worker( +private[worker] class Worker( host: String, port: Int, webUiPort: Int, @@ -60,80 +60,90 @@ private[spark] class Worker( Utils.checkHost(host, "Expected hostname") assert (port > 0) - def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs + // For worker and executor IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 + private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 // Model retries to connect to the master, after Hadoop's model. // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) // Afterwards, the next 10 attempts are between 30 and 90 seconds. // A bit of randomness is introduced so that not all of the workers attempt to reconnect at // the same time. - val INITIAL_REGISTRATION_RETRIES = 6 - val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 - val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 - val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { + private val INITIAL_REGISTRATION_RETRIES = 6 + private val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 + private val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 + private val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * + private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 + private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) + private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 + private val CLEANUP_INTERVAL_MILLIS = + conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - - val testing: Boolean = sys.props.contains("spark.testing") - var master: ActorSelection = null - var masterAddress: Address = null - var activeMasterUrl: String = "" - var activeMasterWebUiUrl : String = "" - val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) - @volatile var registered = false - @volatile var connected = false - val workerId = generateWorkerId() - val sparkHome = + private val APP_DATA_RETENTION_SECS = + conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + + private val testing: Boolean = sys.props.contains("spark.testing") + private var master: ActorSelection = null + private var masterAddress: Address = null + private var activeMasterUrl: String = "" + private[worker] var activeMasterWebUiUrl : String = "" + private val akkaUrl = AkkaUtils.address( + AkkaUtils.protocol(context.system), + actorSystemName, + host, + port, + actorName) + @volatile private var registered = false + @volatile private var connected = false + private val workerId = generateWorkerId() + private val sparkHome = if (testing) { assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") new File(sys.props("spark.test.home")) } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } + var workDir: File = null - val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] + val executors = new HashMap[String, ExecutorRunner] val finishedDrivers = new HashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) - val publicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") + private val publicAddress = { + val envVar = conf.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host } - var webUi: WorkerWebUI = null + private var webUi: WorkerWebUI = null - var coresUsed = 0 - var memoryUsed = 0 - var connectionAttemptCount = 0 + private var connectionAttemptCount = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) - val workerSource = new WorkerSource(this) + private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) + private val workerSource = new WorkerSource(this) + + private var registrationRetryTimer: Option[Cancellable] = None - var registrationRetryTimer: Option[Cancellable] = None + var coresUsed = 0 + var memoryUsed = 0 def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed - def createWorkDir() { + private def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() @@ -170,12 +180,13 @@ private[spark] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - def changeMaster(url: String, uiUrl: String) { + private def changeMaster(url: String, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + master = context.actorSelection( + Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) + masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) connected = true // Cancel any outstanding re-registration attempts because we found a new master registrationRetryTimer.foreach(_.cancel()) @@ -246,7 +257,7 @@ private[spark] class Worker( } } - def registerWithMaster() { + private def registerWithMaster() { // DisassociatedEvent may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. registrationRetryTimer match { @@ -264,7 +275,7 @@ private[spark] class Worker( } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true @@ -339,18 +350,30 @@ private[spark] class Worker( } // Create local dirs for the executor. These are passed to the executor via the - // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the + // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir).getAbsolutePath() + Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs - - val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, sparkHome, executorDir, akkaUrl, conf, appLocalDirs, - ExecutorState.LOADING) + val manager = new ExecutorRunner( + appId, + execId, + appDesc.copy(command = Worker.maybeUpdateSSLSettings(appDesc.command, conf)), + cores_, + memory_, + self, + workerId, + host, + webUi.boundPort, + publicAddress, + sparkHome, + executorDir, + akkaUrl, + conf, + appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -406,7 +429,15 @@ private[spark] class Worker( case LaunchDriver(driverId, driverDesc) => { logInfo(s"Asked to launch driver $driverId") - val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl) + val driver = new DriverRunner( + conf, + driverId, + workDir, + sparkHome, + driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), + self, + akkaUrl, + securityMgr) drivers(driverId) = driver driver.start() @@ -481,7 +512,7 @@ private[spark] class Worker( } } - def generateWorkerId(): String = { + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } @@ -496,7 +527,7 @@ private[spark] class Worker( } } -private[spark] object Worker extends Logging { +private[deploy] object Worker extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf @@ -514,19 +545,41 @@ private[spark] object Worker extends Logging { memory: Int, masterUrls: Array[String], workDir: String, - workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workerNumber: Option[Int] = None, + conf: SparkConf = new SparkConf): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) + val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } + def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { + val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r + val result = cmd.javaOpts.collectFirst { + case pattern(_result) => _result.toBoolean + } + result.getOrElse(false) + } + + def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = { + val prefix = "spark.ssl." + val useNLC = "spark.ssl.useNodeLocalConf" + if (isUseLocalNodeSSLConfig(cmd)) { + val newJavaOpts = cmd.javaOpts + .filter(opt => !opt.startsWith(s"-D$prefix")) ++ + conf.getAll.collect { case (key, value) if key.startsWith(prefix) => s"-D$key=$value" } :+ + s"-D$useNLC=true" + cmd.copy(javaOpts = newJavaOpts) + } else { + cmd + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 019cd70f2a22..88f9d880ac20 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf /** * Command-line parser for the worker. */ -private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { +private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 0 var webUiPort = 8081 @@ -63,7 +63,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { checkWorkerMemory() - def parse(args: List[String]): Unit = args match { + private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) host = value diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index df1e01b23b93..b36023bc40c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -21,7 +21,7 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source -private[spark] class WorkerSource(val worker: Worker) extends Source { +private[worker] class WorkerSource(val worker: Worker) extends Source { override val sourceName = "worker" override val metricRegistry = new MetricRegistry() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b61..83fb991891a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,58 +17,63 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends RpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + } } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address - def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) + private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive: PartialFunction[Any, Unit] = { + case e => logWarning(s"Received unexpected message: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index ecb358c39981..88170d4df305 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils import org.apache.spark.Logging import org.apache.spark.util.logging.RollingFileAppender -private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { +private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b90503280..9f9f27d71e1a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -31,10 +31,9 @@ import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - val workerActor = parent.worker.self - val worker = parent.worker - val timeout = parent.timeout +private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { + private val workerActor = parent.worker.self + private val timeout = parent.timeout override def renderJson(request: HttpServletRequest): JValue = { val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] @@ -134,7 +133,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { def driverRow(driver: DriverRunner): Seq[Node] = { {driver.driverId} - {driver.driverDesc.command.arguments(1)} + {driver.driverDesc.command.arguments(2)} {driver.finalState.getOrElse(DriverState.RUNNING)} {driver.driverDesc.cores.toString} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 7ac81a2d87ef..b3bb5f911dbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -25,12 +25,12 @@ import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone worker. */ -private[spark] +private[worker] class WorkerWebUI( val worker: Worker, val workDir: File, @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - val timeout = AkkaUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askTimeout(worker.conf) initialize() @@ -53,6 +53,6 @@ class WorkerWebUI( } } -private[spark] object WorkerWebUI { +private[ui] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d7..8af46f3327ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,47 +17,67 @@ package org.apache.spark.executor +import java.net.URL import java.nio.ByteBuffer -import scala.concurrent.Await +import scala.collection.mutable +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, + userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None + + // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need + // to be changed so that we don't share the serializer instance across threads + private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } + } + + def extractLogUrls: Map[String, String] = { + val prefix = "SPARK_LOG_URL_" + sys.env.filterKeys(_.startsWith(prefix)) + .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, env, isLocal = false) + executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -68,7 +88,6 @@ private[spark] class CoarseGrainedExecutorBackend( logError("Received LaunchTask command but executor was null") System.exit(1) } else { - val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, @@ -83,19 +102,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -107,7 +135,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname: String, cores: Int, appId: String, - workerUrl: Option[String]) { + workerUrl: Option[String], + userClassPath: Seq[URL]) { SignalLogger.register(log) @@ -118,17 +147,27 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( - "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val fetcher = RpcEnv.create( + "driverPropsFetcher", + hostname, + port, + executorConf, + new SecurityManager(executorConf)) + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. - val driverConf = new SparkConf().setAll(props) + val driverConf = new SparkConf() + for ((key, value) <- props) { + // this is required for SSL in standalone mode + if (SparkConf.isExecutorStartupConf(key)) { + driverConf.setIfMissing(key, value) + } else { + driverConf.set(key, value) + } + } val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) @@ -136,34 +175,81 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } def main(args: Array[String]) { - args.length match { - case x if x < 5 => - System.err.println( + var driverUrl: String = null + var executorId: String = null + var hostname: String = null + var cores: Int = 0 + var appId: String = null + var workerUrl: Option[String] = None + val userClassPath = new mutable.ListBuffer[URL]() + + var argv = args.toList + while (!argv.isEmpty) { + argv match { + case ("--driver-url") :: value :: tail => + driverUrl = value + argv = tail + case ("--executor-id") :: value :: tail => + executorId = value + argv = tail + case ("--hostname") :: value :: tail => + hostname = value + argv = tail + case ("--cores") :: value :: tail => + cores = value.toInt + argv = tail + case ("--app-id") :: value :: tail => + appId = value + argv = tail + case ("--worker-url") :: value :: tail => // Worker url is used in spark standalone mode to enforce fate-sharing with worker - "Usage: CoarseGrainedExecutorBackend " + - " [] ") - System.exit(1) + workerUrl = Some(value) + argv = tail + case ("--user-class-path") :: value :: tail => + userClassPath += new URL(value) + argv = tail + case Nil => + case tail => + System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + printUsageAndExit() + } + } - // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) - // and CoarseMesosSchedulerBackend (for mesos mode). - case 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), None) - case x if x > 5 => - run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) + if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || + appId == null) { + printUsageAndExit() } + + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) } + + private def printUsageAndExit() = { + System.err.println( + """ + |"Usage: CoarseGrainedExecutorBackend [options] + | + | Options are: + | --driver-url + | --executor-id + | --hostname + | --cores + | --app-id + | --worker-url + | --user-class-path + |""".stripMargin) + System.exit(1) + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala new file mode 100644 index 000000000000..f47d7ef511da --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{TaskCommitDenied, TaskEndReason} + +/** + * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver. + */ +private[spark] class CommitDeniedException( + msg: String, + jobID: Int, + splitID: Int, + attemptID: Int) + extends Exception(msg) { + + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a1409..f57e215c3f2e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -19,33 +19,38 @@ package org.apache.spark.executor import java.io.File import java.lang.management.ManagementFactory +import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent._ +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} +import org.apache.spark.util._ /** - * Spark executor used with Mesos, YARN, and the standalone scheduler. - * In coarse-grained mode, an existing actor system is provided. + * Spark executor, backed by a threadpool to run tasks. + * + * This can be used with Mesos, YARN, and the standalone scheduler. + * An internal RPC interface (at the moment Akka) is used for communication with the driver, + * except in the case of Mesos fine-grained mode. */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, + userClassPath: Seq[URL] = Nil, isLocal: Boolean = false) - extends Logging -{ + extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -55,15 +60,13 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. @@ -72,17 +75,21 @@ private[spark] class Executor( Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) } - val executorSource = new ExecutorSource(this, executorId) - conf.set("spark.executor.id", executorId) + // Start worker thread pool + private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val executorSource = new ExecutorSource(threadPool, executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) + + // Whether to load classes in user jars before those in Spark jars + private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -90,7 +97,7 @@ private[spark] class Executor( private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) // Set the classloader for serializer - env.serializer.setDefaultClassLoader(urlClassLoader) + env.serializer.setDefaultClassLoader(replClassLoader) // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. @@ -99,12 +106,12 @@ private[spark] class Executor( // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) - // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + startDriverHeartbeater() def launchTask( @@ -112,31 +119,35 @@ private[spark] class Executor( taskId: Long, attemptNumber: Int, taskName: String, - serializedTask: ByteBuffer) { + serializedTask: ByteBuffer): Unit = { val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean) { + def killTask(taskId: Long, interruptThread: Boolean): Unit = { val tr = runningTasks.get(taskId) if (tr != null) { tr.kill(interruptThread) } } - def stop() { + def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) - isStopped = true + env.rpcEnv.stop(executorEndpoint) + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() } } - private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + /** Returns the total amount of time this JVM process has spent in garbage collection. */ + private def computeTotalGcTime(): Long = { + ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + } class TaskRunner( execBackend: ExecutorBackend, @@ -146,12 +157,19 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + /** Whether this task has been killed. */ @volatile private var killed = false - @volatile var task: Task[Any] = _ - @volatile var attemptedTask: Option[Task[Any]] = None + + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ - def kill(interruptThread: Boolean) { + /** + * The task to run. This will be set in run() by deserializing the task binary coming + * from the driver. Once it is set, it will never be changed. + */ + @volatile var task: Task[Any] = _ + + def kill(interruptThread: Boolean): Unit = { logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { @@ -159,14 +177,14 @@ private[spark] class Executor( } } - override def run() { + override def run(): Unit = { val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var taskStart: Long = 0 - startGCTime = gcTime + startGCTime = computeTotalGcTime() try { val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) @@ -183,7 +201,6 @@ private[spark] class Executor( throw new TaskKilledException } - attemptedTask = Some(task) logDebug("Task " + taskId + "'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) @@ -203,20 +220,23 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.setExecutorDeserializeTime(taskStart - deserializeStartTime) - m.setExecutorRunTime(taskFinish - taskStart) - m.setJvmGCTime(gcTime - startGCTime) + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + m.setExecutorDeserializeTime( + (taskStart - deserializeStartTime) + task.executorDeserializeTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) } val accumUpdates = Accumulators.values - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val serializedResult = { + val serializedResult: ByteBuffer = { if (maxResultSize > 0 && resultSize > maxResultSize) { logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + @@ -238,37 +258,40 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => { + case ffe: FetchFailedException => val reason = ffe.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - case _: TaskKilledException | _: InterruptedException if task.killed => { + case _: TaskKilledException | _: InterruptedException if task.killed => logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - } - case t: Throwable => { + case cDE: CommitDeniedException => + val reason = cDE.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + + case t: Throwable => // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.setExecutorRunTime(serviceTime) - m.setJvmGCTime(gcTime - startGCTime) + val metrics: Option[TaskMetrics] = Option(task).flatMap { task => + task.metrics.map { m => + m.setExecutorRunTime(System.currentTimeMillis() - taskStart) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) + m + } } - val reason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + val taskEndReason = new ExceptionFailure(t, metrics) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { SparkUncaughtExceptionHandler.uncaughtException(t) } - } + } finally { // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() @@ -286,17 +309,23 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): MutableURLClassLoader = { + // Bootstrap the list of jars with the user class path. + val now = System.currentTimeMillis() + userClassPath.foreach { url => + currentJars(url.getPath().split("/").last) = now + } + val currentLoader = Utils.getContextOrSparkClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. - val urls = currentJars.keySet.map { uri => + val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL - }.toArray - val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) - userClassPathFirst match { - case true => new ChildExecutorURLClassLoader(urls, currentLoader) - case false => new ExecutorURLClassLoader(urls, currentLoader) + } + if (userClassPathFirst) { + new ChildFirstURLClassLoader(urls, currentLoader) + } else { + new MutableURLClassLoader(urls, currentLoader) } } @@ -308,14 +337,13 @@ private[spark] class Executor( val classUri = conf.get("spark.repl.class.uri", null) if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - val userClassPathFirst: java.lang.Boolean = - conf.getBoolean("spark.files.userClassPathFirst", false) try { + val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, userClassPathFirst) + constructor.newInstance(conf, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -338,82 +366,86 @@ private[spark] class Executor( for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, - env.securityManager, hadoopConf, timestamp, useCache = !isLocal) - currentJars(name) = timestamp - // Add it to our class loader + for ((name, timestamp) <- newJars) { val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + val currentTimeStamp = currentJars.get(name) + .orElse(currentJars.get(localName)) + .getOrElse(-1L) + if (currentTimeStamp < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) + currentJars(name) = timestamp + // Add it to our class loader + val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL + if (!urlClassLoader.getURLs().contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } } - def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) - - val t = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) - - while (!isStopped) { - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() - val curGCTime = gcTime - - for (taskRunner <- runningTasks.values()) { - if (taskRunner.attemptedTask.nonEmpty) { - Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics() - metrics.updateInputMetrics() - metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } - } - } - } - - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() - } - } catch { - case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t) + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + + /** Reports heartbeat and metrics for active tasks to the driver. */ + private def reportHeartBeat(): Unit = { + // list of (task id, metrics) to send back to the driver + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + val curGCTime = computeTotalGcTime() + + for (taskRunner <- runningTasks.values()) { + if (taskRunner.task != null) { + taskRunner.task.metrics.foreach { metrics => + metrics.updateShuffleReadMetrics() + metrics.updateInputMetrics() + metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + + if (isLocal) { + // JobProgressListener will hold an reference of it during + // onExecutorMetricsUpdate(), then JobProgressListener can not see + // the changes of metrics any more, so make a deep copy of it + val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) + tasksMetrics += ((taskRunner.taskId, copiedMetrics)) + } else { + // It will be copied by serialization + tasksMetrics += ((taskRunner.taskId, metrics)) } - - Thread.sleep(interval) } } } - t.setDaemon(true) - t.setName("Driver Heartbeater") - t.start() + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + try { + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + } catch { + case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + } + } + + /** + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. + */ + private def startDriverHeartbeater(): Unit = { + val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) + } + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala similarity index 67% rename from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala rename to core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala index 41925f7e97e8..cf362f846473 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive} private[spark] case object TriggerThreadDump /** - * Actor that runs inside of executors to enable driver -> executor RPC. + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case TriggerThreadDump => - sender ! Utils.getThreadDump() + context.reply(Utils.getThreadDump()) } } + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index c4d73622c472..293c512f8b70 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import java.util.concurrent.ThreadPoolExecutor + import scala.collection.JavaConversions._ import com.codahale.metrics.{Gauge, MetricRegistry} @@ -24,9 +26,11 @@ import org.apache.hadoop.fs.FileSystem import org.apache.spark.metrics.source.Source -private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source { +private[spark] +class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { + private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption + FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { @@ -41,23 +45,23 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getActiveCount() + override def getValue: Int = threadPool.getActiveCount() }) // Gauge for executor thread pool's approximate total number of tasks that have been completed metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] { - override def getValue: Long = executor.threadPool.getCompletedTaskCount() + override def getValue: Long = threadPool.getCompletedTaskCount() }) // Gauge for executor thread pool's current number of threads metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getPoolSize() + override def getValue: Int = threadPool.getPoolSize() }) // Gauge got executor thread pool's largest number of threads that have ever simultaneously // been in th pool metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getMaximumPoolSize() + override def getValue: Int = threadPool.getMaximumPoolSize() }) // Gauge for file system stats of this executor diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 218ed7b5d2d3..000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import java.net.{URLClassLoader, URL} - -import org.apache.spark.util.ParentClassLoader - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - * We also make changes so user classes can come before the default classes. - */ - -private[spark] trait MutableURLClassLoader extends ClassLoader { - def addURL(url: URL) - def getURLs: Array[URL] -} - -private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends MutableURLClassLoader { - - private object userClassLoader extends URLClassLoader(urls, null){ - override def addURL(url: URL) { - super.addURL(url) - } - override def findClass(name: String): Class[_] = { - super.findClass(name) - } - } - - private val parentClassLoader = new ParentClassLoader(parent) - - override def findClass(name: String): Class[_] = { - try { - userClassLoader.findClass(name) - } catch { - case e: ClassNotFoundException => { - parentClassLoader.loadClass(name) - } - } - } - - def addURL(url: URL) { - userClassLoader.addURL(url) - } - - def getURLs() = { - userClassLoader.getURLs() - } -} - -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) with MutableURLClassLoader { - - override def addURL(url: URL) { - super.addURL(url) - } -} - diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ddb5903bf687..06152f16ae61 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,14 +17,10 @@ package org.apache.spark.executor -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.executor.DataReadMethod.DataReadMethod - import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} /** @@ -45,14 +41,14 @@ class TaskMetrics extends Serializable { * Host's name the task runs on */ private var _hostname: String = _ - def hostname = _hostname + def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ - def executorDeserializeTime = _executorDeserializeTime + def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value @@ -60,14 +56,14 @@ class TaskMetrics extends Serializable { * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ - def executorRunTime = _executorRunTime + def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value /** * The number of bytes this task transmitted back to the driver as the TaskResult */ private var _resultSize: Long = _ - def resultSize = _resultSize + def resultSize: Long = _resultSize private[spark] def setResultSize(value: Long) = _resultSize = value @@ -75,31 +71,31 @@ class TaskMetrics extends Serializable { * Amount of time the JVM spent in garbage collection while executing this task */ private var _jvmGCTime: Long = _ - def jvmGCTime = _jvmGCTime + def jvmGCTime: Long = _jvmGCTime private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ private var _resultSerializationTime: Long = _ - def resultSerializationTime = _resultSerializationTime + def resultSerializationTime: Long = _resultSerializationTime private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value + def memoryBytesSpilled: Long = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ private var _diskBytesSpilled: Long = _ - def diskBytesSpilled = _diskBytesSpilled - def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value + def diskBytesSpilled: Long = _diskBytesSpilled + def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -107,7 +103,7 @@ class TaskMetrics extends Serializable { */ private var _inputMetrics: Option[InputMetrics] = None - def inputMetrics = _inputMetrics + def inputMetrics: Option[InputMetrics] = _inputMetrics /** * This should only be used when recreating TaskMetrics, not when updating input metrics in @@ -129,7 +125,7 @@ class TaskMetrics extends Serializable { */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - def shuffleReadMetrics = _shuffleReadMetrics + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** * This should only be used when recreating TaskMetrics, not when updating read metrics in @@ -178,35 +174,40 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): - InputMetrics =synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } } } /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ - private[spark] def updateShuffleReadMetrics() = synchronized { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + private[spark] def updateShuffleReadMetrics(): Unit = synchronized { + if (!depsShuffleReadMetrics.isEmpty) { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.incFetchWaitTime(depMetrics.fetchWaitTime) + merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) + merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) + merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + merged.incLocalBytesRead(depMetrics.localBytesRead) + merged.incRecordsRead(depMetrics.recordsRead) + } + _shuffleReadMetrics = Some(merged) } - _shuffleReadMetrics = Some(merged) } - private[spark] def updateInputMetrics() = synchronized { + private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } } @@ -243,27 +244,31 @@ object DataWriteMethod extends Enumeration with Serializable { @DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { - private val _bytesRead: AtomicLong = new AtomicLong() + /** + * This is volatile so that it is visible to the updater thread. + */ + @volatile @transient var bytesReadCallback: Option[() => Long] = None /** * Total bytes read. */ - def bytesRead: Long = _bytesRead.get() - @volatile @transient var bytesReadCallback: Option[() => Long] = None + private var _bytesRead: Long = _ + def bytesRead: Long = _bytesRead + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes /** - * Adds additional bytes read for this read method. + * Total records read. */ - def addBytesRead(bytes: Long) = { - _bytesRead.addAndGet(bytes) - } + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. */ def updateBytesRead() { bytesReadCallback.foreach { c => - _bytesRead.set(c()) + _bytesRead = c() } } @@ -286,8 +291,15 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { * Total bytes written */ private var _bytesWritten: Long = _ - def bytesWritten = _bytesWritten - private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + + /** + * Total records written + */ + private var _recordsWritten: Long = 0L + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value } /** @@ -300,18 +312,17 @@ class ShuffleReadMetrics extends Serializable { * Number of remote blocks fetched in this shuffle by this task */ private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched = _remoteBlocksFetched + def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def defRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value /** * Number of local blocks fetched in this shuffle by this task */ private var _localBlocksFetched: Int = _ - def localBlocksFetched = _localBlocksFetched + def localBlocksFetched: Int = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def defLocalBlocksFetched(value: Int) = _localBlocksFetched -= value - + private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value /** * Time the task spent waiting for remote shuffle blocks. This only includes the time @@ -319,7 +330,7 @@ class ShuffleReadMetrics extends Serializable { * still not finished processing block A, it is not considered to be blocking on block B. */ private var _fetchWaitTime: Long = _ - def fetchWaitTime = _fetchWaitTime + def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value @@ -327,14 +338,34 @@ class ShuffleReadMetrics extends Serializable { * Total number of remote bytes read from the shuffle by this task */ private var _remoteBytesRead: Long = _ - def remoteBytesRead = _remoteBytesRead + def remoteBytesRead: Long = _remoteBytesRead private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead: Long = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + + /** + * Total number of records read from the shuffle by this task + */ + private var _recordsRead: Long = _ + def recordsRead: Long = _recordsRead + private[spark] def incRecordsRead(value: Long) = _recordsRead += value + private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } /** @@ -347,7 +378,7 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for the shuffle by this task */ @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten = _shuffleBytesWritten + def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value @@ -355,9 +386,16 @@ class ShuffleWriteMetrics extends Serializable { * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime= _shuffleWriteTime + def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** + * Total number of records written to the shuffle by this task + */ + @volatile private var _shuffleRecordsWritten: Long = _ + def shuffleRecordsWritten: Long = _shuffleRecordsWritten + private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value + private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value + private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 593a62b3e3b3..6cda7772f77b 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -73,16 +73,16 @@ private[spark] abstract class StreamBasedRecordReader[T]( private var key = "" private var value: T = null.asInstanceOf[T] - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} - override def close() = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: T = value - override def nextKeyValue = { + override def nextKeyValue: Boolean = { if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) @@ -119,7 +119,8 @@ private[spark] class StreamRecordReader( * The format for the PortableDataStream files */ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { - override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) + : CombineFileRecordReader[String, PortableDataStream] = { new CombineFileRecordReader[String, PortableDataStream]( split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) } @@ -204,7 +205,7 @@ class PortableDataStream( /** * Close the file (if it is currently open) */ - def close() = { + def close(): Unit = { if (isOpen) { try { fileIn.close() diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index f856890d279f..0709b6d689e8 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -26,7 +26,6 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils -import org.apache.spark.Logging /** * :: DeveloperApi :: @@ -53,8 +52,12 @@ private[spark] object CompressionCodec { "lzf" -> classOf[LZFCompressionCodec].getName, "snappy" -> classOf[SnappyCompressionCodec].getName) + def getCodecName(conf: SparkConf): String = { + conf.get(configKey, DEFAULT_COMPRESSION_CODEC) + } + def createCodec(conf: SparkConf): CompressionCodec = { - createCodec(conf, conf.get(configKey, DEFAULT_COMPRESSION_CODEC)) + createCodec(conf, getCodecName(conf)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { @@ -71,6 +74,20 @@ private[spark] object CompressionCodec { s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) } + /** + * Return the short version of the given codec name. + * If it is already a short name, just return it. + */ + def getShortName(codecName: String): String = { + if (shortCompressionCodecNames.contains(codecName)) { + codecName + } else { + shortCompressionCodecNames + .collectFirst { case (k, v) if v == codecName => k } + .getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") } + } + } + val FALLBACK_COMPRESSION_CODEC = "lzf" val DEFAULT_COMPRESSION_CODEC = "snappy" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq diff --git a/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala b/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala new file mode 100644 index 000000000000..a83501253105 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/SparkSubmitArgumentsParser.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +/** + * This class makes SparkSubmitOptionParser visible for Spark code outside of the `launcher` + * package, since Java doesn't have a feature similar to `private[spark]`, and we don't want + * that class to be public. + */ +private[spark] abstract class SparkSubmitArgumentsParser extends SparkSubmitOptionParser diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala new file mode 100644 index 000000000000..9be98723aed1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.io.File +import java.util.{HashMap => JHashMap, List => JList, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.deploy.Command + +/** + * This class is used by CommandUtils. It uses some package-private APIs in SparkLauncher, and since + * Java doesn't have a feature similar to `private[spark]`, and we don't want that class to be + * public, needs to live in the same package as the rest of the library. + */ +private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) + extends AbstractCommandBuilder { + + childEnv.putAll(command.environment) + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) + + override def buildCommand(env: JMap[String, String]): JList[String] = { + val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) + cmd.add(s"-Xms${memoryMb}M") + cmd.add(s"-Xmx${memoryMb}M") + command.javaOpts.foreach(cmd.add) + addPermGenSizeOpt(cmd) + addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) + cmd + } + + def buildCommand(): JList[String] = buildCommand(new JHashMap[String, String]()) + +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9..818f7a4c8d42 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -17,9 +17,15 @@ package org.apache.spark.mapred +import java.io.IOException import java.lang.reflect.Modifier -import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} + +import org.apache.spark.executor.CommitDeniedException +import org.apache.spark.{Logging, SparkEnv, TaskContext} private[spark] trait SparkHadoopMapRedUtil { @@ -52,7 +58,7 @@ trait SparkHadoopMapRedUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } @@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil { } } } + +object SparkHadoopMapRedUtil extends Logging { + /** + * Commits a task output. Before committing the task output, we need to know whether some other + * task attempt might be racing to commit the same output partition. Therefore, coordinate with + * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for + * details). + * + * Output commit coordinator is only contacted when the following two configurations are both set + * to `true`: + * + * - `spark.speculation` + * - `spark.hadoop.outputCommitCoordination.enabled` + */ + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + jobId: Int, + splitId: Int, + attemptId: Int): Unit = { + + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + + // Called after we have decided to commit + def performCommit(): Unit = { + try { + committer.commitTask(mrTaskContext) + logInfo(s"$mrTaskAttemptID: Committed") + } catch { + case cause: IOException => + logError(s"Error committing the output of task: $mrTaskAttemptID", cause) + committer.abortTask(mrTaskContext) + throw cause + } + } + + // First, check whether the task's output has already been committed by some other attempt + if (committer.needsTaskCommit(mrTaskContext)) { + val shouldCoordinateWithDriver: Boolean = { + val sparkConf = SparkEnv.get.conf + // We only need to coordinate with the driver if there are multiple concurrent task + // attempts, which should only occur if speculation is enabled + val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + } + + if (shouldCoordinateWithDriver) { + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + + if (canCommit) { + performCommit() + } else { + val message = + s"$mrTaskAttemptID: Not committed because the driver did not authorize commit" + logInfo(message) + // We need to abort the task so that the driver can reschedule new attempts, if necessary + committer.abortTask(mrTaskContext) + throw new CommitDeniedException(message, jobId, splitId, attemptId) + } + } else { + // Speculation is disabled or a user has chosen to manually bypass the commit coordination + performCommit() + } + } else { + // Some other attempt committed the output, so we do nothing and signal success + logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") + } + } + + def commitTask( + committer: MapReduceOutputCommitter, + mrTaskContext: MapReduceTaskAttemptContext, + sparkTaskContext: TaskContext): Unit = { + commitTask( + committer, + mrTaskContext, + sparkTaskContext.stageId(), + sparkTaskContext.partitionId(), + sparkTaskContext.attemptNumber()) + } +} diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 3340673f9115..cfd20392d12f 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -45,7 +45,7 @@ trait SparkHadoopMapReduceUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 1b7a5d1f1980..8edf49378068 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -28,12 +28,12 @@ import org.apache.spark.util.Utils private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" + private val DEFAULT_PREFIX = "*" + private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r + private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties" - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null + private[metrics] val properties = new Properties() + private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null private def setDefaultProperties(prop: Properties) { prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet") @@ -47,20 +47,22 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi setDefaultProperties(properties) // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF) + val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { + try { + Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) + } catch { + case e: Exception => + logError("Error loading default configuration file", e) + None } + } - if (is != null) { + isOpt.foreach { is => + try { properties.load(is) + } finally { + is.close() } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() } propertyCategories = subProperties(properties, INSTANCE_REGEX) diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01d..9150ad35712a 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} @@ -84,7 +85,7 @@ private[spark] class MetricsSystem private ( /** * Get any UI handlers used by this metrics system; can only be called after start(). */ - def getServletHandlers = { + def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") metricsServlet.map(_.getHandlers).getOrElse(Array()) } @@ -130,8 +131,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } @@ -191,7 +192,10 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e) + case e: Exception => { + logError("Sink class " + classPath + " cannot be instantialized") + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index d7b5f5c40efa..2d25ebd66159 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem @@ -38,6 +38,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric val GRAPHITE_KEY_PERIOD = "period" val GRAPHITE_KEY_UNIT = "unit" val GRAPHITE_KEY_PREFIX = "prefix" + val GRAPHITE_KEY_PROTOCOL = "protocol" def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop)) @@ -66,7 +67,11 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) + case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") + } val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) .convertDurationsTo(TimeUnit.MILLISECONDS) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b4660..0c2e212a3307 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,8 +30,12 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class MetricsServlet( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -45,10 +49,12 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[ServletContextHandler]( - createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) - ) + def getHandlers: Array[ServletContextHandler] = { + Array[ServletContextHandler]( + createServletHandler(servletPath, + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + ) + } def getMetricsSnapshot(request: HttpServletRequest): String = { mapper.writeValueAsString(registry) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca..9fad4e7deacb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { - def start: Unit - def stop: Unit + def start(): Unit + def stop(): Unit def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala new file mode 100644 index 000000000000..e8b3074e8f1a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import com.codahale.metrics.{Slf4jReporter, MetricRegistry} + +import org.apache.spark.SecurityManager +import org.apache.spark.metrics.MetricsSystem + +private[spark] class Slf4jSink( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SLF4J_DEFAULT_PERIOD = 10 + val SLF4J_DEFAULT_UNIT = "SECONDS" + + val SLF4J_KEY_PERIOD = "period" + val SLF4J_KEY_UNIT = "unit" + + val pollPeriod = Option(property.getProperty(SLF4J_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => SLF4J_DEFAULT_PERIOD + } + + val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter: Slf4jReporter = Slf4jReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build() + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } + + override def report() { + reporter.report() + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index a1a2c00ed154..1ba25aa74aa0 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -32,11 +32,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int): BlockMessage = blockMessages(i) - def iterator = blockMessages.iterator + def iterator: Iterator[BlockMessage] = blockMessages.iterator - def length = blockMessages.length + def length: Int = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 3b245c5c7a4f..9a9e22b0c236 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -31,9 +31,9 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val initialSize = currentSize() var gotChunkForSendingOnce = false - def size = initialSize + def size: Int = initialSize - def currentSize() = { + def currentSize(): Int = { if (buffers == null || buffers.isEmpty) { 0 } else { @@ -100,11 +100,11 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: buffers.foreach(_.flip) } - def hasAckId() = (ackId != 0) + def hasAckId(): Boolean = ackId != 0 - def isCompletelyReceived() = !buffers(0).hasRemaining + def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - override def toString = { + override def toString: String = { if (hasAckId) { "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index c2d9578be7eb..6b898bd4bfc1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -101,9 +101,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, socketRemoteConnectionManagerId } - def key() = channel.keyFor(selector) + def key(): SelectionKey = channel.keyFor(selector) - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + def getRemoteAddress(): InetSocketAddress = { + channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + } // Returns whether we have to register for further reads or not. def read(): Boolean = { @@ -179,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { @@ -280,7 +282,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, /* channel.socket.setSendBufferSize(256 * 1024) */ - override def getRemoteAddress() = address + override def getRemoteAddress(): InetSocketAddress = address val DEFAULT_INTEREST = SelectionKey.OP_READ diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 764dc5e5503e..b3b281ff465f 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -18,7 +18,9 @@ package org.apache.spark.network.nio private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString: String = { + connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + } } private[nio] object ConnectionId { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 03c4137ca0a8..16e905982cf6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -36,7 +36,7 @@ import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import scala.util.Try import scala.util.control.NonFatal @@ -79,17 +79,18 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = - new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) + new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = - conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120)) + conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", + conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -101,7 +102,7 @@ private[nio] class ConnectionManager( handlerThreadCount, conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) { + ThreadUtils.namedThreadFactory("handle-message-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -116,7 +117,7 @@ private[nio] class ConnectionManager( ioThreadCount, conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) { + ThreadUtils.namedThreadFactory("handle-read-write-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -133,7 +134,7 @@ private[nio] class ConnectionManager( connectThreadCount, conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) { + ThreadUtils.namedThreadFactory("handle-connect-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -159,7 +160,7 @@ private[nio] class ConnectionManager( private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("Connection manager future execution context")) + ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) @volatile private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null @@ -184,14 +185,17 @@ private[nio] class ConnectionManager( // to be able to track asynchronous messages private val idCount: AtomicInteger = new AtomicInteger(1) + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + @volatile private var isActive = true private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() + override def run(): Unit = ConnectionManager.this.run() } selectorThread.setDaemon(true) + // start this thread last, since it invokes run(), which accesses members above selectorThread.start() - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private def triggerWrite(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) if (conn == null) return @@ -232,7 +236,6 @@ private[nio] class ConnectionManager( } ) } - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private def triggerRead(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) @@ -340,7 +343,7 @@ private[nio] class ConnectionManager( def run() { try { - while(!selectorThread.isInterrupted) { + while (isActive) { while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) @@ -396,7 +399,7 @@ private[nio] class ConnectionManager( } catch { // Explicitly only dealing with CancelledKeyException here since other exceptions // should be dealt with differently. - case e: CancelledKeyException => { + case e: CancelledKeyException => // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() @@ -418,8 +421,11 @@ private[nio] class ConnectionManager( } } } - } - 0 + 0 + + case e: ClosedSelectorException => + logDebug("Failed select() as selector is closed.", e) + return } if (selectedKeysCount == 0) { @@ -986,10 +992,11 @@ private[nio] class ConnectionManager( } def stop() { + isActive = false ackTimeoutMonitor.stop() + selector.close() selectorThread.interrupt() selectorThread.join() - selector.close() val connections = connectionsByKey.values connections.foreach(_.close()) if (connectionsByKey.size != 0) { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index cbb37ec5ced1..1cd13d887c6f 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -26,7 +26,7 @@ private[nio] case class ConnectionManagerId(host: String, port: Int) { Utils.checkHost(host) assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) + def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index fb4a979b824c..85d2fe2bf9c2 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -42,7 +42,9 @@ private[nio] abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString: String = { + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + } } @@ -51,7 +53,7 @@ private[nio] object Message { var lastId = 1 - def getNewId() = synchronized { + def getNewId(): Int = synchronized { lastId += 1 if (lastId == 0) { lastId += 1 diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 278c5ac356ef..a4568e849fa1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining + val size: Int = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { + lazy val buffers: ArrayBuffer[ByteBuffer] = { val ab = new ArrayBuffer[ByteBuffer]() ab += header.buffer if (buffer != null) { @@ -35,7 +35,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { ab } - override def toString = { + override def toString: String = { "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 6e20f291c5ce..7b3da4bb9d5e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -50,8 +50,10 @@ private[nio] class MessageChunkHeader( flip.asInstanceOf[ByteBuffer] } - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + override def toString: String = { + "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5ad73c3d27f4..2ab41ba488ff 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -27,8 +27,7 @@ package org.apache * contains operations available only on RDDs of Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can * be saved as SequenceFiles. These operations are automatically available on any RDD of the right - * type (e.g. RDD[(Int, Int)] through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * type (e.g. RDD[(Int, Int)] through implicit conversions. * * Java programmers should reference the [[org.apache.spark.api.java]] package * for Spark programming APIs in Java. @@ -44,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.3.0-SNAPSHOT" + val SPARK_VERSION = "1.4.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index cadd0c7ed19b..53c4b32c95ab 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -99,7 +99,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { case None => "(partial: " + initialValue + ")" } } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + def getFinalValueInternal(): Option[T] = PartialResult.this.getFinalValueInternal().map(f) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac06..3406a7e97e36 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc..71578d1210fd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 1cbd684224b7..c1d697178757 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c37284..0d130dd4c7a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -48,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -112,8 +113,11 @@ private[spark] object CheckpointRDD extends Logging { } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f..658e8c8b8931 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle - -private[spark] sealed trait CoGroupSplitDep extends Serializable +/** The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. */ private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, + @transient rdd: RDD[_], + @transient splitIndex: Int, var split: Partition - ) extends CoGroupSplitDep { + ) extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep - -private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +/** + * Stores information about the narrow dependencies used by a CoGroupedRdd. + * + * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one + * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing + * the partition for that dependency) at the corresponding index. The size of + * narrowDeps should always be equal to the number of parents. + */ +private[spark] class CoGroupPartition( + idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -99,15 +107,15 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = dependencies.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for ((dep, depNum) <- dependencies.zipWithIndex) dep match { + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent - val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(handle) => + case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle val it = SparkEnv.get.shuffleManager - .getReader(handle, split.index, split.index + 1, context) + .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index b073eba8a157..0c1b02c07d09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD @@ -186,7 +186,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: override val isEmpty = !it.hasNext // initializes/resets to start iterating from the beginning - def resetIterator() = { + def resetIterator(): Iterator[(String, Partition)] = { val iterators = (0 to 2).map( x => prev.partitions.iterator.flatMap(p => { if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None @@ -196,10 +196,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: } // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } + override def hasNext: Boolean = { !isEmpty } // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { + override def next(): (String, Partition) = { if (it.hasNext) { it.next() } else { @@ -237,7 +237,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { + if (!rotIt.hasNext) { (1 to targetLen).foreach(x => groupArr += PartitionGroup()) return } @@ -343,7 +343,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: private case class PartitionGroup(prefLoc: Option[String] = None) { var arr = mutable.ArrayBuffer[Partition]() - def size = arr.size + def size: Int = arr.size } private object PartitionGroup { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index e66f83bb34e3..843a893235e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.StatCounter class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Add up the elements in this RDD. */ def sum(): Double = { - self.reduce(_ + _) + self.fold(0.0)(_ + _) } /** @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -191,29 +191,34 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } - self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + if (self.partitions.length == 0) { + new Array[Long](buckets.length - 1) + } else { + // reduce() requires a non-empty RDD. This works because the mapPartitions will make + // non-empty partitions out of empty ones. But it doesn't handle the no-partitions case, + // which is below + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 056aef0bc210..f77abac42b62 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,16 +35,18 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} +import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. @@ -213,18 +215,17 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -238,14 +239,16 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - override def getNext() = { + override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { case eof: EOFException => finished = true } - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (key, value) } @@ -254,11 +257,12 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -306,6 +310,15 @@ class HadoopRDD[K, V]( // Do nothing. Hadoop RDD should not be checkpointed. } + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } + def getConf: Configuration = getJobConf() } @@ -323,11 +336,11 @@ private[spark] object HadoopRDD extends Logging { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - def putCachedMetadata(key: String, value: Any) = + private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) /** Add Hadoop configuration specific to a single partition and attempt. */ @@ -357,7 +370,7 @@ private[spark] object HadoopRDD extends Logging { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[HadoopPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 642a12c1edf6..0c28f045e46e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.sql.{Connection, ResultSet} +import java.sql.{PreparedStatement, Connection, ResultSet} import scala.reflect.ClassTag @@ -28,8 +28,9 @@ import org.apache.spark.util.NextIterator import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx + override def index: Int = idx } + // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. @@ -62,15 +63,16 @@ class JdbcRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { // bounds are inclusive, hence the + 1 here and - 1 on end - val length = 1 + upperBound - lowerBound + val length = BigInt(1) + upperBound - lowerBound (0 until numPartitions).map(i => { - val start = lowerBound + ((i * length) / numPartitions).toLong - val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 - new JdbcPartition(i, start, end) + val start = lowerBound + ((i * length) / numPartitions) + val end = lowerBound + (((i + 1) * length) / numPartitions) - 1 + new JdbcPartition(i, start.toLong, end.toLong) }).toArray } - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] + { context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() @@ -88,7 +90,7 @@ class JdbcRDD[T: ClassTag]( stmt.setLong(2, part.upper) val rs = stmt.executeQuery() - override def getNext: T = { + override def getNext(): T = { if (rs.next()) { mapRow(rs) } else { @@ -99,21 +101,21 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) { + if (null != rs) { rs.close() } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) { + if (null != stmt) { stmt.close() } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! conn.isClosed()) { + if (null != conn) { conn.close() } logInfo("closed connection") diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4883fb828814..a838aac6e8d1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -31,6 +31,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b0e3c87ccff..2ab967f4bb31 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,20 +25,17 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat -import org.apache.spark.InterruptibleIterator -import org.apache.spark.Logging -import org.apache.spark.Partition -import org.apache.spark.SerializableWritable -import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark._ +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, @@ -114,13 +111,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -154,7 +151,9 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - + if (!finished) { + inputMetrics.incRecordsRead(1) + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -163,11 +162,12 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -210,6 +210,16 @@ class NewHadoopRDD[K, V]( locs.getOrElse(split.getLocations.filter(_ != "localhost")) } + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } + + def getConf: Configuration = confBroadcast.value.value } @@ -228,7 +238,7 @@ private[spark] object NewHadoopRDD { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[NewHadoopPartition] val inputSplit = partition.serializableHadoopSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 144f679a5946..6afe50161dac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) @@ -75,4 +75,27 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) } + /** + * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be + * performed efficiently by only scanning the partitions that might contain matching elements. + * Otherwise, a standard `filter` is applied to all partitions. + */ + def filterByRange(lower: K, upper: K): RDD[P] = { + + def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) + + val rddToFilter: RDD[P] = self.partitioner match { + case Some(rp: RangePartitioner[K, V]) => { + val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { + case (l, u) => Math.min(l, u) to Math.max(l, u) + } + PartitionPruningRDD.create(self, partitionIndicies.contains) + } + case _ => + self + } + rddToFilter.filter { case (k, v) => inRange(k) } + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0f37d830ef34..1548fac960ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter} + RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = @@ -990,11 +990,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - try { - var recordsWritten = 0L + require(writer != null, "Unable to obtain RecordWriter") + var recordsWritten = 0L + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1003,11 +1004,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) 1 } : Int @@ -1061,12 +1063,13 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() - try { - var recordsWritten = 0L + var recordsWritten = 0L + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1075,22 +1078,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } self.context.runJob(self, writeToFile) - writer.commitJob() + writer.commitJob() // hive insert clones here } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) @@ -1100,9 +1101,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 - && bytesWrittenCallback.isDefined) { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } + outputMetrics.setRecordsWritten(recordsWritten) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index f12d0cffaba3..e2394e28f8d2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -98,7 +98,7 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = { + override def compute(s: Partition, context: TaskContext): Iterator[T] = { new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index f781a8d776f2..a00f4c1cdff9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -40,7 +40,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) } } @@ -59,8 +59,10 @@ class PartitionPruningRDD[T: ClassTag]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + } override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions @@ -74,7 +76,7 @@ object PartitionPruningRDD { * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD * when its type T is not known at compile time. */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean): PartitionPruningRDD[T] = { new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index ed79032893d3..dc60d4892762 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -149,10 +149,10 @@ private[spark] class PipedRDD[T: ClassTag]( }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines + val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next() = lines.next() - def hasNext = { + def next(): String = lines.next() + def hasNext: Boolean = { if (lines.hasNext) { true } else { @@ -162,7 +162,7 @@ private[spark] class PipedRDD[T: ClassTag]( } // cleanup task working directory if used - if (workInTaskDirectory == true) { + if (workInTaskDirectory) { scala.util.control.Exception.ignoring(classOf[IOException]) { Utils.deleteRecursively(new File(taskDirectory)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97012c7033f9..d80d94a58834 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -25,11 +25,8 @@ import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text} import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -57,8 +54,7 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] - * through implicit conversions except `saveAsSequenceFile`. You need to - * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work. + * through implicit. * * Internally, each RDD is characterized by five main properties: * @@ -76,10 +72,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. + logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -173,7 +186,7 @@ abstract class RDD[T: ClassTag]( } /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel + def getStorageLevel: StorageLevel = storageLevel // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed @@ -303,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -364,6 +377,12 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. + * + * @param withReplacement can elements be sampled multiple times (replaced when sampled out) + * @param fraction expected size of the sample as a fraction of this RDD's size + * without replacement: probability that each element is chosen; fraction must be [0, 1] + * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * @param seed seed for the random number generator */ def sample(withReplacement: Boolean, fraction: Double, @@ -449,7 +468,13 @@ abstract class RDD[T: ClassTag]( * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). */ - def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + def union(other: RDD[T]): RDD[T] = { + if (partitioner.isDefined && other.partitioner == partitioner) { + new PartitionerAwareUnionRDD(sc, Array(this, other)) + } else { + new UnionRDD(sc, Array(this, other)) + } + } /** * Return the union of this RDD and another one. Any identical elements will appear multiple @@ -463,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -587,8 +612,8 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} * @param separateWorkingDir Use separate working directories for each task. * @return the result RDD */ @@ -721,13 +746,13 @@ abstract class RDD[T: ClassTag]( def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { - def hasNext = (thisIter.hasNext, otherIter.hasNext) match { + def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true case (false, false) => false case _ => throw new SparkException("Can only zip RDDs with " + "same number of elements in each partition") } - def next = (thisIter.next, otherIter.next) + def next(): (T, U) = (thisIter.next(), otherIter.next()) } } } @@ -824,10 +849,10 @@ abstract class RDD[T: ClassTag]( * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -843,8 +868,8 @@ abstract class RDD[T: ClassTag]( // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + override def numPartitions: Int = p.numPartitions + override def getPartition(k: Any): Int = p.getPartition(k.asInstanceOf[(Any, _)]._1) } // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies // anyway, and when calling .keys, will not have a partitioner set, even though @@ -883,6 +908,38 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -909,7 +966,7 @@ abstract class RDD[T: ClassTag]( */ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) @@ -918,6 +975,37 @@ abstract class RDD[T: ClassTag]( jobResult } + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (partitions.length == 0) { + return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.length + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } + /** * Return the number of elements in the RDD. */ @@ -938,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -947,7 +1035,7 @@ abstract class RDD[T: ClassTag]( * * Note that this method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which + * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { @@ -973,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -985,7 +1073,7 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` * would trigger sparse representation of registers, which may reduce the memory consumption * and increase accuracy when the cardinality is small. * @@ -1052,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1064,6 +1152,9 @@ abstract class RDD[T: ClassTag]( * Take the first num elements of the RDD. It works by first scanning one partition, and use the * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. + * + * @note due to complications in the internal implementation, this method will raise + * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = { if (num == 0) { @@ -1152,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1176,6 +1267,10 @@ abstract class RDD[T: ClassTag]( def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min) /** + * @note due to complications in the internal implementation, this method will raise an + * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice + * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. + * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) * @return true if and only if the RDD contains no elements at all. Note that an RDD * may be empty even when it has at least 1 partition. */ @@ -1299,11 +1394,11 @@ abstract class RDD[T: ClassTag]( } /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc + def context: SparkContext = sc /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. + * Used for internal Java-Scala API compatibility. */ private[spark] def retag(cls: Class[T]): RDD[T] = { val classTag: ClassTag[T] = ClassTag.apply(cls) @@ -1312,7 +1407,7 @@ abstract class RDD[T: ClassTag]( /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. + * Used for internal Java-Scala API compatibility. */ private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { this.mapPartitions(identity, preservesPartitioning = true)(classTag) @@ -1394,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1404,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( @@ -1447,7 +1542,7 @@ abstract class RDD[T: ClassTag]( */ object RDD { - // The following implicit functions were in SparkContext before 1.2 and users had to + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward // compatibility and forward to the following functions directly. @@ -1461,9 +1556,15 @@ object RDD { new AsyncRDDActions(rdd) } - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - new SequenceFileRDDFunctions(rdd) + implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], + keyWritableFactory: WritableFactory[K], + valueWritableFactory: WritableFactory[V]) + : SequenceFileRDDFunctions[K, V] = { + implicit val keyConverter = keyWritableFactory.convert + implicit val valueConverter = valueWritableFactory.convert + new SequenceFileRDDFunctions(rdd, + keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt)) } implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)]) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f185797..1722c27e5500 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException} +import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) @@ -92,12 +92,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( new SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD @@ -130,5 +135,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -// Used for synchronization -private[spark] object RDDCheckpointData +private[spark] object RDDCheckpointData { + def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + } + + def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { + rddCheckpointDataPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 2b4891695143..059f8963691f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -30,13 +30,35 @@ import org.apache.spark.Logging * through an implicit conversion. Note that this can't be part of PairRDDFunctions because * we need more implicit parameters to convert our keys and values to Writable. * - * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( - self: RDD[(K, V)]) + self: RDD[(K, V)], + _keyWritableClass: Class[_ <: Writable], + _valueWritableClass: Class[_ <: Writable]) extends Logging with Serializable { + @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") + def this(self: RDD[(K, V)]) { + this(self, null, null) + } + + private val keyWritableClass = + if (_keyWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[K]() + } else { + _keyWritableClass + } + + private val valueWritableClass = + if (_valueWritableClass == null) { + // pre 1.3.0, we need to use Reflection to get the Writable class + getWritableClass[V]() + } else { + _valueWritableClass + } + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { @@ -55,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag c.asInstanceOf[Class[_ <: Writable]] } + /** * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key * and value types. If the key or value are Writable, then we use their classes directly; @@ -65,26 +88,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass) + // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and + // valueWritableClass at the compile time. To implement that, we need to add type parameters to + // SequenceFileRDDFunctions. however, SequenceFileRDDFunctions is a public class so it will be a + // breaking change. + val convertKey = self.keyClass != keyWritableClass + val convertValue = self.valueClass != valueWritableClass - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + - valueClass.getSimpleName + ")" ) + logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," + + valueWritableClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) + self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) + path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index d9fe6847254f..2dc47f95937c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.rdd -import scala.reflect.ClassTag - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx + override val index: Int = idx override def hashCode(): Int = idx } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index ed24ea22a661..633aeba3bbae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,14 +76,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + dependencies(depNum) match { + case oneToOneDependency: OneToOneDependency[_] => + val dependencyPartition = partition.narrowDeps(depNum).get.split + oneToOneDependency.rdd.iterator(dependencyPartition, context) + .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(handle) => - val iter = SparkEnv.get.shuffleManager - .getReader(handle, partition.index, partition.index + 1, context) - .read() - iter.foreach(op) + case shuffleDependency: ShuffleDependency[_, _, _] => + val iter = SparkEnv.get.shuffleManager + .getReader( + shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + .read() + iter.foreach(op) + } } + // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) + integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) + integrate(1, t => map.remove(t._1)) map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index aece683ff319..3986645350a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -44,7 +44,7 @@ private[spark] class UnionPartition[T: ClassTag]( var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(parentPartition) + def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition) override val index: Int = idx @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 95b2dd954e9f..a96b6c3d2345 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -32,7 +32,7 @@ private[spark] class ZippedPartitionsPartition( override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions: Seq[Partition] = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f..523aaf2b860b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 000000000000..a5336b756380 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.util.{RpcUtils, Utils} + +/** + * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to + * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. + * + * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. + */ +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. + */ + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. + */ + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. + */ + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. + */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + } + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. + */ + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String +} + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). + newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Call onError when any exception is thrown during handling messages. + * + * @param cause + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(_self) + } + } +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +trait ThreadSafeRpcEndpoint extends RpcEndpoint + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any): Future[T] = + sendWithReply(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } + +} + +/** + * Represent a host with a port + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + fromURI(new java.net.URI(uri)) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 000000000000..652e52f2b2e7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error +import akka.pattern.{ask => akkaAsk} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private[akka] ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. + // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. + lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + assert(endpointRef != null) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + + case e: AssociationEvent => + // TODO ignore? + + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") + safelyCall(endpoint) { + processMessage(endpoint, m, sender) + } + + case AkkaFailure(e) => + safelyCall(endpoint) { + throw e + } + + case message: Any => { + logWarning(s"Unknown message: $message") + } + + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) + // Now actorRef can be created safely + endpointRef.init() + endpointRef + } + + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val needReply = m.needReply + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + _sender ! AkkaFailure(e) + } else { + throw e + } + } + } + + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + import actorSystem.dispatcher + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" +} + +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } +} + +/** + * Monitor errors reported by Akka and log them. + */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef(conf) with Logging { + + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } + + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }.mapTo[T] + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + +} + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. + * @param message + * @param needReply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index fa83372bb4d1..e0edd7d4ae96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -39,8 +39,11 @@ class AccumulableInfo ( } object AccumulableInfo { - def apply(id: Long, name: String, update: Option[String], value: String) = + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value) + } - def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) + def apply(id: Long, name: String, value: String): AccumulableInfo = { + new AccumulableInfo(id, name, None, value) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index b755d8fb1575..50a69379412d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.CallSite */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: Stage, + val finalStage: ResultStage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: CallSite, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1cfe98673773..8c4bff4e83af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,26 +19,22 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps -import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils} +import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -54,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -63,7 +63,7 @@ class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { @@ -84,7 +84,7 @@ class DAGScheduler( private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] // Stages we need to run whose parents aren't done @@ -98,8 +98,14 @@ class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - // Contains the locations that each RDD's partitions are cached on - private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ + private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -109,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -121,7 +129,7 @@ class DAGScheduler( private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) private val messageScheduler = - Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) @@ -143,7 +151,7 @@ class DAGScheduler( result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + taskMetrics: TaskMetrics): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) } @@ -158,41 +166,40 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. - def executorLost(execId: String) { + def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - def executorAdded(execId: String, host: String) { + def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason)) } - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { + private[scheduler] + def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) - cacheLocs(rdd.id) = blockIds.map { id => - locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) + val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) } + cacheLocs(rdd.id) = locs } cacheLocs(rdd.id) } - private def clearCacheLocs() { + private def clearCacheLocs(): Unit = cacheLocs.synchronized { cacheLocs.clear() } @@ -201,40 +208,65 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { + private def getShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + jobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies registerShuffleDependencies(shuffleDep, jobId) // Then register current shuffleDep - val stage = - newOrUsedStage( - shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId, - shuffleDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - + stage } } /** - * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation - * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided - * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage - * directly. + * Helper function to eliminate some code re-use when creating new stages. */ - private def newStage( + private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, jobId) + val id = nextStageId.getAndIncrement() + (parentStages, id) + } + + /** + * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in + * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * Production of shuffle map stages should always use newOrUsedShuffleStage, not + * newShuffleMapStage directly. + */ + private def newShuffleMapStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_, _, _]], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, - callSite: CallSite) - : Stage = - { - val parentStages = getParentStages(rdd, jobId) - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) + callSite: CallSite): ShuffleMapStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, + jobId, callSite, shuffleDep) + + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Create a ResultStage -- either directly for use as a result stage, or as part of the + * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated + * with the provided jobId. + */ + private def newResultStage( + rdd: RDD[_], + numTasks: Int, + jobId: Int, + callSite: CallSite): ResultStage = { + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) + stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -246,20 +278,17 @@ class DAGScheduler( * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ - private def newOrUsedStage( - rdd: RDD[_], - numTasks: Int, + private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, - callSite: CallSite) - : Stage = - { - val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + jobId: Int): ShuffleMapStage = { + val rdd = shuffleDep.rdd + val numTasks = rdd.partitions.size + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) } else { @@ -297,26 +326,23 @@ class DAGScheduler( } } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents.toList } - // Find ancestor missing shuffle dependencies and register into shuffleToMapStage - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = { + /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (!parentsWithNoMapStage.isEmpty) { + while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = - newOrUsedStage( - currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId, - currentShufDep.rdd.creationSite) + val stage = newOrUsedShuffleStage(currentShufDep, jobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } - // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet + /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] @@ -342,7 +368,7 @@ class DAGScheduler( } waitingForVisit.push(rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } parents @@ -373,7 +399,7 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } missing.toList @@ -383,7 +409,7 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -403,7 +429,7 @@ class DAGScheduler( * * @param job The job whose state to cleanup. */ - private def cleanupStateForJobAndIndependentStages(job: ActiveJob) { + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { logError("No stages registered for job " + job.jobId) @@ -465,8 +491,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = - { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -488,22 +513,20 @@ class DAGScheduler( waiter } - def runJob[T, U: ClassTag]( + def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => { + case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - } case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) @@ -517,9 +540,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -532,12 +553,12 @@ class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int) { + def cancelJob(jobId: Int): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId)) } - def cancelJobGroup(groupId: String) { + def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } @@ -545,7 +566,7 @@ class DAGScheduler( /** * Cancel all jobs that are running or waiting in the queue. */ - def cancelAllJobs() { + def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } @@ -624,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => @@ -648,7 +669,7 @@ class DAGScheduler( // completion events or stage abort stageIdToStage -= s.id jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult)) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) } } @@ -666,7 +687,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -693,11 +714,12 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } @@ -713,13 +735,12 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) - { - var finalStage: Stage = null + properties: Properties) { + var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -736,7 +757,7 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) val shouldRunLocally = localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 - val jobSubmissionTime = clock.getTime() + val jobSubmissionTime = clock.getTimeMillis() if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post( @@ -764,7 +785,7 @@ class DAGScheduler( if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) - if (missing == Nil) { + if (missing.isEmpty) { logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage, jobId.get) } else { @@ -785,22 +806,19 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() + // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = { - if (stage.isShuffleMap) { - (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil) - } else { - val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + stage match { + case stage: ShuffleMapStage => + (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + case stage: ResultStage => + val job = stage.resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } + val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -808,6 +826,7 @@ class DAGScheduler( // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -820,18 +839,21 @@ class DAGScheduler( try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = - if (stage.isShuffleMap) { - closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() - } else { - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() - } + val taskBinaryBytes: Array[Byte] = stage match { + case stage: ShuffleMapStage => + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + case stage: ResultStage => + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage + + // Abort execution return case NonFatal(e) => abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") @@ -839,20 +861,22 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = if (stage.isShuffleMap) { - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } - } else { - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + val tasks: Seq[Task[_]] = stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } + + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } } if (tasks.size > 0) { @@ -861,14 +885,22 @@ class DAGScheduler( logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stage.latestInfo.submissionTime = Some(clock.getTime()) + stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - runningStages -= stage + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + val debugString = stage match { + case stage: ShuffleMapStage => + s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})" + case stage : ResultStage => + s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + } + logDebug(debugString) } } @@ -879,8 +911,16 @@ class DAGScheduler( if (event.accumUpdates != null) { try { Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // In this instance, although the reference in Accumulators.originals is a WeakRef, + // it's guaranteed to exist since the event.accumUpdates Map exists + + val acc = Accumulators.originals(id).get match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => throw new NullPointerException("Non-existent reference to Accumulator") + } + // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get @@ -909,6 +949,9 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + outputCommitCoordinator.taskCompleted(stageId, task.partitionId, + event.taskInfo.attempt, event.reason) + // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { @@ -921,23 +964,8 @@ class DAGScheduler( // Skip all the actions if the stage has been cancelled. return } - val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTime()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } + val stage = stageIdToStage(task.stageId) event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -945,7 +973,10 @@ class DAGScheduler( stage.pendingTasks -= task task match { case rt: ResultTask[_, _] => - stage.resultOfJob match { + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -953,10 +984,10 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - markStageAsFinished(stage) + markStageAsFinished(resultStage) cleanupStateForJobAndIndependentStages(job) listenerBus.post( - SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded)) + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } // taskSucceeded runs some user code that might throw an exception. Make sure @@ -965,7 +996,7 @@ class DAGScheduler( job.listener.taskSucceeded(rt.outputId, event.result) } catch { case e: Exception => - // TODO: Perhaps we want to mark the stage as failed? + // TODO: Perhaps we want to mark the resultStage as failed? job.listener.jobFailed(new SparkDriverExecutionException(e)) } } @@ -974,6 +1005,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId @@ -981,50 +1013,54 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partitionId, status) + shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) { - markStageAsFinished(stage) + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - if (stage.shuffleDep.isDefined) { - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, - changeEpoch = true) - } + + // We supply true to increment the epoch number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // epoch incremented to refetch them. + // TODO: Only increment the epoch number if this is not the first time + // we registered these map outputs. + mapOutputTracker.registerMapOutputs( + shuffleStage.shuffleDep.shuffleId, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = true) + clearCacheLocs() - if (stage.outputLocs.exists(_ == Nil)) { - // Some tasks had failed; let's resubmit this stage + if (shuffleStage.outputLocs.contains(Nil)) { + // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " (" + stage.name + + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) + .map(_._2).mkString(", ")) + submitStage(shuffleStage) } else { val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waitingStages) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) + for (shuffleStage <- waitingStages) { + logInfo("Missing parents for " + shuffleStage + ": " + + getMissingParentStages(shuffleStage)) } - for (stage <- waitingStages if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage + for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) + { + newlyRunnable += shuffleStage } waitingStages --= newlyRunnable runningStages ++= newlyRunnable for { - stage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(stage) + shuffleStage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(shuffleStage) } { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage, jobId) + logInfo("Submitting " + shuffleStage + " (" + + shuffleStage.rdd + "), which is now runnable") + submitMissingTasks(shuffleStage, jobId) } } } @@ -1045,7 +1081,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1073,6 +1108,9 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } + case commitDenied: TaskCommitDenied => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures @@ -1158,6 +1196,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1169,7 +1227,7 @@ class DAGScheduler( } val dependentJobs: Seq[ActiveJob] = activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq - failedStage.latestInfo.completionTime = Some(clock.getTime()) + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") } @@ -1178,9 +1236,7 @@ class DAGScheduler( } } - /** - * Fails a job and all stages that are only used by that job, and cleans up relevant state. - */ + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { val error = new SparkException(failureReason) var ableToCancelStages = true @@ -1209,8 +1265,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) @@ -1224,19 +1279,16 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } } - /** - * Return true if one of stage's ancestors is target. - */ + /** Return true if one of stage's ancestors is target. */ private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true } val visitedRdds = new HashSet[RDD[_]] - val visitedStages = new HashSet[Stage] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting val waitingForVisit = new Stack[RDD[_]] @@ -1248,7 +1300,6 @@ class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { - visitedStages += mapStage waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => @@ -1258,30 +1309,37 @@ class DAGScheduler( } } waitingForVisit.push(stage.rdd) - while (!waitingForVisit.isEmpty) { + while (waitingForVisit.nonEmpty) { visit(waitingForVisit.pop()) } visitedRdds.contains(target.rdd) } /** - * Synchronized method that might be called from other threads. + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. + * * @param rdd whose partitions are to be looked at * @param partition to lookup locality information for * @return list of machines that are preferred by the partition */ private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { getPreferredLocsInternal(rdd, partition, new HashSet) } - /** Recursive implementation for getPreferredLocs. */ + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]) - : Seq[TaskLocation] = - { + visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 if (!visited.add((rdd,partition))) { @@ -1290,12 +1348,12 @@ class DAGScheduler( } // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { + if (cached.nonEmpty) { return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { + if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep @@ -1379,7 +1437,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.sc.stop() } - override def onStop() { + override def onStop(): Unit = { // Cancel any active jobs in postStop hook dagScheduler.cleanUpAfterSchedulerStop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 30075c172bdb..08e7727db2fd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,30 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) + private val compressionCodec = + if (shouldCompress) { + Some(CompressionCodec.createCodec(sparkConf)) + } else { + None + } + private val compressionCodecName = compressionCodec.map { c => + CompressionCodec.getShortName(c.getClass.getName) + } // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None @@ -80,13 +89,13 @@ private[spark] class EventLoggingListener( private[scheduler] val loggedEvents = new ArrayBuffer[JValue] // Visible for tests only. - private[scheduler] val logPath = getLogPath(logBaseDir, appId) + private[scheduler] val logPath = getLogPath(logBaseDir, appId, compressionCodecName) /** * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.isDirectory(new Path(logBaseDir))) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -111,19 +120,19 @@ private[spark] class EventLoggingListener( hadoopDataStream.get } - val compressionCodec = - if (shouldCompress) { - Some(CompressionCodec.createCodec(sparkConf)) - } else { - None - } - - fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) - val logStream = initEventLog(new BufferedOutputStream(dstream, outputBufferSize), - compressionCodec) - writer = Some(new PrintWriter(logStream)) + try { + val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) + val bstream = new BufferedOutputStream(cstream, outputBufferSize) - logInfo("Logging events to %s".format(logPath)) + EventLoggingListener.initEventLog(bstream) + fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) + writer = Some(new PrintWriter(bstream)) + logInfo("Logging events to %s".format(logPath)) + } catch { + case e: Exception => + dstream.close() + throw e + } } /** Log the event as JSON. */ @@ -140,47 +149,60 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted) = - logEvent(event) - override def onTaskStart(event: SparkListenerTaskStart) = - logEvent(event) - override def onTaskGettingResult(event: SparkListenerTaskGettingResult) = - logEvent(event) - override def onTaskEnd(event: SparkListenerTaskEnd) = - logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) = - logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + + override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = logEvent(event) + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = - logEvent(event, flushLogger = true) - override def onJobStart(event: SparkListenerJobStart) = + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { logEvent(event, flushLogger = true) - override def onJobEnd(event: SparkListenerJobEnd) = - logEvent(event, flushLogger = true) - override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) = + } + + override def onJobStart(event: SparkListenerJobStart): Unit = logEvent(event, flushLogger = true) + + override def onJobEnd(event: SparkListenerJobEnd): Unit = logEvent(event, flushLogger = true) + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) = + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { logEvent(event, flushLogger = true) - override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { logEvent(event, flushLogger = true) - override def onApplicationStart(event: SparkListenerApplicationStart) = + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { logEvent(event, flushLogger = true) - override def onApplicationEnd(event: SparkListenerApplicationEnd) = + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { logEvent(event, flushLogger = true) - override def onExecutorAdded(event: SparkListenerExecutorAdded) = + } + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { logEvent(event, flushLogger = true) - override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { logEvent(event, flushLogger = true) + } // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. */ - def stop() = { + def stop(): Unit = { writer.foreach(_.close()) val target = new Path(logPath) @@ -201,77 +223,57 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" + val SPARK_VERSION_KEY = "SPARK_VERSION" + val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) - // Marker for the end of header data in a log file. After this marker, log data, potentially - // compressed, will be found. - private val HEADER_END_MARKER = "=== LOG_HEADER_END ===" - - // To avoid corrupted files causing the heap to fill up. Value is arbitrary. - private val MAX_HEADER_LINE_LENGTH = 4096 - // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] /** - * Write metadata about the event log to the given stream. - * - * The header is a serialized version of a map, except it does not use Java serialization to - * avoid incompatibilities between different JDKs. It writes one map entry per line, in - * "key=value" format. - * - * The very last entry in the header is the `HEADER_END_MARKER` marker, so that the parsing code - * can know when to stop. + * Write metadata about an event log to the given stream. + * The metadata is encoded in the first line of the event log as JSON. * - * The format needs to be kept in sync with the openEventLog() method below. Also, it cannot - * change in new Spark versions without some other way of detecting the change (like some - * metadata encoded in the file name). - * - * @param logStream Raw output stream to the even log file. - * @param compressionCodec Optional compression codec to use. - * @return A stream where to write event log data. This may be a wrapper around the original - * stream (for example, when compression is enabled). + * @param logStream Raw output stream to the event log file. */ - def initEventLog( - logStream: OutputStream, - compressionCodec: Option[CompressionCodec]): OutputStream = { - val meta = mutable.HashMap(("version" -> SPARK_VERSION)) - compressionCodec.foreach { codec => - meta += ("compressionCodec" -> codec.getClass().getName()) - } - - def write(entry: String) = { - val bytes = entry.getBytes(Charsets.UTF_8) - if (bytes.length > MAX_HEADER_LINE_LENGTH) { - throw new IOException(s"Header entry too long: ${entry}") - } - logStream.write(bytes, 0, bytes.length) - } - - meta.foreach { case (k, v) => write(s"$k=$v\n") } - write(s"$HEADER_END_MARKER\n") - compressionCodec.map(_.compressedOutputStream(logStream)).getOrElse(logStream) + def initEventLog(logStream: OutputStream): Unit = { + val metadata = SparkListenerLogStart(SPARK_VERSION) + val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + logStream.write(metadataJson.getBytes(Charsets.UTF_8)) } /** * Return a file-system-safe path to the log file for the given application. * + * Note that because we currently only create a single log file for each application, + * we must encode all the information needed to parse this event log in the file name + * instead of within the file itself. Otherwise, if the file is compressed, for instance, + * we won't know which codec to use to decompress the metadata needed to open the file in + * the first place. + * * @param logBaseDir Directory where the log file will be written. * @param appId A unique app ID. + * @param compressionCodecName Name to identify the codec used to compress the contents + * of the log, or None if compression is not enabled. * @return A path which consists of file-system-safe characters. */ - def getLogPath(logBaseDir: String, appId: String): String = { - val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase - Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") + def getLogPath( + logBaseDir: URI, + appId: String, + compressionCodecName: Option[String] = None): String = { + val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + // e.g. app_123, app_123.lzf + val logName = sanitizedAppId + compressionCodecName.map { "." + _ }.getOrElse("") + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** - * Opens an event log file and returns an input stream to the event data. + * Opens an event log file and returns an input stream that contains the event data. * - * @return 2-tuple (event input stream, Spark version of event data) + * @return input stream that holds one JSON record per line. */ - def openEventLog(log: Path, fs: FileSystem): (InputStream, String) = { + def openEventLog(log: Path, fs: FileSystem): InputStream = { // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain // IOException when a file does not exist, so try our best to throw a proper exception. if (!fs.exists(log)) { @@ -279,52 +281,17 @@ private[spark] object EventLoggingListener extends Logging { } val in = new BufferedInputStream(fs.open(log)) - // Read a single line from the input stream without buffering. - // We cannot use BufferedReader because we must avoid reading - // beyond the end of the header, after which the content of the - // file may be compressed. - def readLine(): String = { - val bytes = new ByteArrayOutputStream() - var next = in.read() - var count = 0 - while (next != '\n') { - if (next == -1) { - throw new IOException("Unexpected end of file.") - } - bytes.write(next) - count = count + 1 - if (count > MAX_HEADER_LINE_LENGTH) { - throw new IOException("Maximum header line length exceeded.") - } - next = in.read() - } - new String(bytes.toByteArray(), Charsets.UTF_8) + + // Compression codec is encoded as an extension, e.g. app_123.lzf + // Since we sanitize the app ID to not include periods, it is safe to split on it + val logName = log.getName.stripSuffix(IN_PROGRESS) + val codecName: Option[String] = logName.split("\\.").tail.lastOption + val codec = codecName.map { c => + codecMap.getOrElseUpdate(c, CompressionCodec.createCodec(new SparkConf, c)) } - // Parse the header metadata in the form of k=v pairs - // This assumes that every line before the header end marker follows this format try { - val meta = new mutable.HashMap[String, String]() - var foundEndMarker = false - while (!foundEndMarker) { - readLine() match { - case HEADER_END_MARKER => - foundEndMarker = true - case entry => - val prop = entry.split("=", 2) - if (prop.length != 2) { - throw new IllegalArgumentException("Invalid metadata in log file.") - } - meta += (prop(0) -> prop(1)) - } - } - - val sparkVersion = meta.get("version").getOrElse( - throw new IllegalArgumentException("Missing Spark version in log metadata.")) - val codec = meta.get("compressionCodec").map { codecName => - codecMap.getOrElseUpdate(codecName, CompressionCodec.createCodec(new SparkConf, codecName)) - } - (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion) + codec.map(_.compressedInputStream(in)).getOrElse(in) } catch { case e: Exception => in.close() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3bb54855bae4..e55b76c36cc5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -57,7 +57,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val stageIdToJobId = new HashMap[Int, Int] private val jobIdToStageIds = new HashMap[Int, Seq[Int]] private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } createLogDir() @@ -169,7 +169,8 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + + " LOCAL_BYTES_READ=" + metrics.localBytesRead case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 29879b374b80..382b09422a4a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -34,7 +34,7 @@ private[spark] class JobWaiter[T]( @volatile private var _jobFinished = totalTasks == 0 - def jobFinished = _jobFinished + def jobFinished: Boolean = _jobFinished // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 36a6e6338faa..be23056e7d42 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,10 +17,9 @@ package org.apache.spark.scheduler -import java.util.concurrent.{LinkedBlockingQueue, Semaphore} +import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.AsynchronousListenerBus /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -29,113 +28,19 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). */ -private[spark] class LiveListenerBus extends SparkListenerBus with Logging { - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - private var started = false - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread("SparkListenerBus") { - setDaemon(true) - override def run(): Unit = Utils.logUncaughtExceptions { - while (true) { - eventLock.acquire() - // Atomically remove and process this event - LiveListenerBus.this.synchronized { - val event = eventQueue.poll - if (event == SparkListenerShutdown) { - // Get out of the while loop and shutdown the daemon thread - return - } - Option(event).foreach(postToAll) - } - } - } - } - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - */ - def start() { - if (started) { - throw new IllegalStateException("Listener bus already started!") +private[spark] class LiveListenerBus + extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus") + with SparkListenerBus { + + private val logDroppedEvent = new AtomicBoolean(false) + + override def onDropEvent(event: SparkListenerEvent): Unit = { + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError("Dropping SparkListenerEvent because no remaining room in event queue. " + + "This likely means one of the SparkListeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") } - listenerThread.start() - started = true } - def post(event: SparkListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - logQueueFullErrorMessage() - } - } - - /** - * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Return true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - /** - * For testing only. Return whether the listener daemon thread is still alive. - */ - def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive } - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty } - - /** - * Log an error message to indicate that the event queue is full. Do this only once. - */ - private def logQueueFullErrorMessage(): Unit = { - if (!queueFullErrorMessageLogged) { - if (listenerThread.isAlive) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with" + - "the rate at which tasks are being started by the scheduler.") - } else { - logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not" + - "been (and will no longer be) propagated to listeners for some time.") - } - queueFullErrorMessageLogged = true - } - } - - def stop() { - if (!started) { - throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") - } - post(SparkListenerShutdown) - listenerThread.join() - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala new file mode 100644 index 000000000000..7c184b1dcb30 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import org.apache.spark._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} + +private sealed trait OutputCommitCoordinationMessage extends Serializable + +private case object StopCoordinator extends OutputCommitCoordinationMessage +private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) + +/** + * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" + * policy. + * + * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. + * + * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) + * for an extensive design discussion. + */ +private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { + + // Initialized by SparkEnv + var coordinatorRef: Option[RpcEndpointRef] = None + + private type StageId = Int + private type PartitionId = Long + private type TaskAttemptId = Long + + /** + * Map from active stages's id => partition id => task attempt with exclusive lock on committing + * output for that partition. + * + * Entries are added to the top-level map when stages start and are removed they finish + * (either successfully or unsuccessfully). + * + * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. + */ + private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() + private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + + /** + * Called by tasks to ask whether they can commit their output to HDFS. + * + * If a task attempt has been authorized to commit, then all other attempts to commit the same + * task will be denied. If the authorized task attempt fails (e.g. due to its executor being + * lost), then a subsequent task attempt may be authorized to commit its output. + * + * @param stage the stage number + * @param partition the partition number + * @param attempt a unique identifier for this task attempt + * @return true if this task is authorized to commit, false otherwise + */ + def canCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attempt) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithReply[Boolean](msg) + case None => + logError( + "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") + false + } + } + + // Called by DAGScheduler + private[scheduler] def stageStart(stage: StageId): Unit = synchronized { + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + } + + // Called by DAGScheduler + private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + authorizedCommittersByStage.remove(stage) + } + + // Called by DAGScheduler + private[scheduler] def taskCompleted( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId, + reason: TaskEndReason): Unit = synchronized { + val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + logDebug(s"Ignoring task completion for completed stage") + return + }) + reason match { + case Success => + // The task output has been committed successfully + case denied: TaskCommitDenied => + logInfo( + s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + case otherReason => + if (authorizedCommitters.get(partition).exists(_ == attempt)) { + logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + + s" clearing lock") + authorizedCommitters.remove(partition) + } + } + } + + def stop(): Unit = synchronized { + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None + authorizedCommittersByStage.clear() + } + + // Marked private[scheduler] instead of private so this can be mocked in tests + private[scheduler] def handleAskPermissionToCommit( + stage: StageId, + partition: PartitionId, + attempt: TaskAttemptId): Boolean = synchronized { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + authorizedCommitters.get(partition) match { + case Some(existingCommitter) => + logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + + s"existingCommitter = $existingCommitter") + false + case None => + logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") + authorizedCommitters(partition) = attempt + true + } + case None => + logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + false + } + } +} + +private[spark] object OutputCommitCoordinator { + + // This actor is used only for RPC + private[spark] class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { + + override def receive: PartialFunction[Any, Unit] = { + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + stop() + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + context.reply( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 584f4e7789d1..86f357abb872 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -21,6 +21,7 @@ import java.io.{InputStream, IOException} import scala.io.Source +import com.fasterxml.jackson.core.JsonParseException import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging @@ -39,22 +40,40 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * error is thrown by this method. * * @param logData Stream containing event log data. - * @param version Spark version that generated the events. + * @param sourceName Filename (or other source identifier) from whence @logData is being read + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * encountered, log file might not finished writing) or not */ - def replay(logData: InputStream, version: String) { + def replay( + logData: InputStream, + sourceName: String, + maybeTruncated: Boolean = false): Unit = { var currentLine: String = null + var lineNumber: Int = 1 try { val lines = Source.fromInputStream(logData).getLines() - lines.foreach { line => - currentLine = line - postToAll(JsonProtocol.sparkEventFromJson(parse(line))) + while (lines.hasNext) { + currentLine = lines.next() + try { + postToAll(JsonProtocol.sparkEventFromJson(parse(currentLine))) + } catch { + case jpe: JsonParseException => + // We can only ignore exception from last line of the file that might be truncated + if (!maybeTruncated || lines.hasNext) { + throw jpe + } else { + logWarning(s"Got JsonParseException from log file $sourceName" + + s" at line $lineNumber, the file might not have finished writing cleanly.") + } + } + lineNumber += 1 } } catch { case ioe: IOException => throw ioe case e: Exception => - logError("Exception in parsing Spark event log.", e) - logError("Malformed line: %s\n".format(currentLine)) + logError(s"Exception parsing Spark event log: $sourceName", e) + logError(s"Malformed line #$lineNumber: $currentLine\n") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala new file mode 100644 index 000000000000..c0f3d5a13d62 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.CallSite + +/** + * The ResultStage represents the final stage in a job. + */ +private[spark] class ResultStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + // The active job for this result stage. Will be empty if the job has already finished + // (e.g., because the job was cancelled). + var resultOfJob: Option[ActiveJob] = None + + override def toString: String = "ResultStage " + id +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe2..c9a124113961 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -53,9 +53,11 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) @@ -64,5 +66,5 @@ private[spark] class ResultTask[T, U]( // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" + override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala new file mode 100644 index 000000000000..d02210743484 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.ShuffleDependency +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.CallSite + +/** + * The ShuffleMapStage represents the intermediate stages in a job. + */ +private[spark] class ShuffleMapStage( + id: Int, + rdd: RDD[_], + numTasks: Int, + parents: List[Stage], + jobId: Int, + callSite: CallSite, + val shuffleDep: ShuffleDependency[_, _, _]) + extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + + override def toString: String = "ShuffleMapStage " + id + + var numAvailableOutputs: Long = 0 + + def isAvailable: Boolean = numAvailableOutputs == numPartitions + + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + + def addOutputLoc(partition: Int, status: MapStatus): Unit = { + val prevList = outputLocs(partition) + outputLocs(partition) = status :: prevList + if (prevList == Nil) { + numAvailableOutputs += 1 + } + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da..bd3dd23dfe1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,7 +33,7 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling @@ -47,7 +47,7 @@ private[spark] class ShuffleMapTask( /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, null, new Partition { override def index = 0 }, null) + this(0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -56,9 +56,11 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null @@ -83,5 +85,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index e5d1eb767e10..b711ff209af9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -91,11 +91,11 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) +case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorRemoved(executorId: String) +case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent /** @@ -116,9 +116,11 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String], @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent -/** An event used in the listener to shutdown the listener daemon thread. */ -private[spark] case object SparkListenerShutdown extends SparkListenerEvent - +/** + * An internal class that describes the metadata of an event log. + * This event is not meant to be posted to listeners downstream. + */ +private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * :: DeveloperApi :: @@ -298,7 +300,7 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d: Double) = format.format(d) + def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -344,7 +346,7 @@ private[spark] object StatsReportListener extends Logging { /** * Reformat a time interval in milliseconds to a prettier format for output */ - def millisToString(ms: Long) = { + def millisToString(ms: Long): String = { val (size, units) = if (ms > hours) { (ms.toDouble / hours, "hours") diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e700c6af542f..61e69ecc0838 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -17,78 +17,48 @@ package org.apache.spark.scheduler -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.ListenerBus /** - * A SparkListenerEvent bus that relays events to its listeners + * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners */ -private[spark] trait SparkListenerBus extends Logging { - - // SparkListeners attached to this event bus - protected val sparkListeners = new ArrayBuffer[SparkListener] - with mutable.SynchronizedBuffer[SparkListener] - - def addListener(listener: SparkListener) { - sparkListeners += listener - } +private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { - /** - * Post an event to all attached listeners. - * This does nothing if the event is SparkListenerShutdown. - */ - def postToAll(event: SparkListenerEvent) { + override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - foreachListener(_.onStageSubmitted(stageSubmitted)) + listener.onStageSubmitted(stageSubmitted) case stageCompleted: SparkListenerStageCompleted => - foreachListener(_.onStageCompleted(stageCompleted)) + listener.onStageCompleted(stageCompleted) case jobStart: SparkListenerJobStart => - foreachListener(_.onJobStart(jobStart)) + listener.onJobStart(jobStart) case jobEnd: SparkListenerJobEnd => - foreachListener(_.onJobEnd(jobEnd)) + listener.onJobEnd(jobEnd) case taskStart: SparkListenerTaskStart => - foreachListener(_.onTaskStart(taskStart)) + listener.onTaskStart(taskStart) case taskGettingResult: SparkListenerTaskGettingResult => - foreachListener(_.onTaskGettingResult(taskGettingResult)) + listener.onTaskGettingResult(taskGettingResult) case taskEnd: SparkListenerTaskEnd => - foreachListener(_.onTaskEnd(taskEnd)) + listener.onTaskEnd(taskEnd) case environmentUpdate: SparkListenerEnvironmentUpdate => - foreachListener(_.onEnvironmentUpdate(environmentUpdate)) + listener.onEnvironmentUpdate(environmentUpdate) case blockManagerAdded: SparkListenerBlockManagerAdded => - foreachListener(_.onBlockManagerAdded(blockManagerAdded)) + listener.onBlockManagerAdded(blockManagerAdded) case blockManagerRemoved: SparkListenerBlockManagerRemoved => - foreachListener(_.onBlockManagerRemoved(blockManagerRemoved)) + listener.onBlockManagerRemoved(blockManagerRemoved) case unpersistRDD: SparkListenerUnpersistRDD => - foreachListener(_.onUnpersistRDD(unpersistRDD)) + listener.onUnpersistRDD(unpersistRDD) case applicationStart: SparkListenerApplicationStart => - foreachListener(_.onApplicationStart(applicationStart)) + listener.onApplicationStart(applicationStart) case applicationEnd: SparkListenerApplicationEnd => - foreachListener(_.onApplicationEnd(applicationEnd)) + listener.onApplicationEnd(applicationEnd) case metricsUpdate: SparkListenerExecutorMetricsUpdate => - foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) + listener.onExecutorMetricsUpdate(metricsUpdate) case executorAdded: SparkListenerExecutorAdded => - foreachListener(_.onExecutorAdded(executorAdded)) + listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => - foreachListener(_.onExecutorRemoved(executorRemoved)) - case SparkListenerShutdown => - } - } - - /** - * Apply the given function to all attached listeners, catching and logging any exception. - */ - private def foreachListener(f: SparkListener => Unit): Unit = { - sparkListeners.foreach { listener => - try { - f(listener) - } catch { - case e: Exception => - logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) - } + listener.onExecutorRemoved(executorRemoved) + case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b8..5d0ddb8377c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.HashSet import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -47,29 +46,23 @@ import org.apache.spark.util.CallSite * be updated for each attempt. * */ -private[spark] class Stage( +private[spark] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, val callSite: CallSite) extends Logging { - val isShuffleMap = shuffleDep.isDefined val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - /** For stages that are the final (consists of only ResultTasks), link to the ActiveJob. */ - var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] - private var nextAttemptId = 0 + private var nextAttemptId: Int = 0 val name = callSite.shortForm val details = callSite.longForm @@ -77,53 +70,6 @@ private[spark] class Stage( /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */ var latestInfo: StageInfo = StageInfo.fromStage(this) - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - /** Return a new attempt id, starting with 0. */ def newAttemptId(): Int = { val id = nextAttemptId @@ -133,11 +79,8 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString = "Stage " + id - - override def hashCode(): Int = id - - override def equals(other: Any): Boolean = other match { + override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 847a4912eec1..b09b19e2ac9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { /** - * Called by Executor to run this task. + * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } @@ -87,11 +87,18 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // initialized when kill() is invoked. @volatile @transient private var _killed = false + protected var _executorDeserializeTime: Long = 0 + /** * Whether the task has been killed. */ def killed: Boolean = _killed + /** + * Returns the amount of time spent deserializing the RDD and function to be run. + */ + def executorDeserializeTime: Long = _executorDeserializeTime + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7..132a9ced7770 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 10c685f29d3a..da07ce2c6ea4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -29,23 +29,22 @@ private[spark] sealed trait TaskLocation { /** * A location that includes both a host and an executor id on that host. */ -private [spark] case class ExecutorCacheTaskLocation(override val host: String, - val executorId: String) extends TaskLocation { -} +private [spark] +case class ExecutorCacheTaskLocation(override val host: String, executorId: String) + extends TaskLocation /** * A location on a host. */ private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { - override def toString = host + override def toString: String = host } /** * A location on a host that is cached by HDFS. */ -private [spark] case class HDFSCacheTaskLocation(override val host: String) - extends TaskLocation { - override def toString = TaskLocation.inMemoryLocationTag + host +private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation { + override def toString: String = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { @@ -54,14 +53,16 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" - def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) + def apply(host: String, executorId: String): TaskLocation = { + new ExecutorCacheTaskLocation(host, executorId) + } /** * Create a TaskLocation from a string returned by getPreferredLocations. * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the * location is cached. */ - def apply(str: String) = { + def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { new HostTaskLocation(str) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 774f3d8cdb27..391827c1d215 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.RejectedExecutionException import scala.language.existentials import scala.util.control.NonFatal @@ -25,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. @@ -34,7 +35,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( THREADS, "task-result-getter") protected val serializer = new ThreadLocal[SerializerInstance] { @@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { var reason : TaskEndReason = UnknownReason - getTaskResultExecutor.execute(new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions { - try { - if (serializedData != null && serializedData.limit() > 0) { - reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + try { + getTaskResultExecutor.execute(new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions { + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, Utils.getSparkClassLoader) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastrophic + // if we can't deserialize the reason. + val loader = Utils.getContextOrSparkClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex: Exception => {} } - } catch { - case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastrophic if we can't - // deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader - logError( - "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) - } - }) + }) + } catch { + case e: RejectedExecutionException if sparkEnv.isStopped => + // ignore it + } } def stop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index f095915352b1..ed3418676e07 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -73,5 +73,9 @@ private[spark] trait TaskScheduler { * @return An application ID */ def applicationId(): String = appId - + + /** + * Process a lost executor + */ + def executorLost(executorId: String, reason: ExecutorLossReason): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 33a7aae5d3fc..13a52d836f32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -62,10 +62,10 @@ private[spark] class TaskSchedulerImpl( val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) + val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) + val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") // CPUs to request per task val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) @@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { - Utils.tryOrExit { checkSpeculatableTasks() } - } + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, + SPECULATION_INTERVAL_MS milliseconds) { + Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } + }(sc.env.actorSystem.dispatcher) } } @@ -158,7 +157,7 @@ private[spark] class TaskSchedulerImpl( val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) @@ -173,13 +172,20 @@ private[spark] class TaskSchedulerImpl( this.cancel() } } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS) } hasReceivedTask = true } backend.reviveOffers() } + // Label as private[scheduler] to allow tests to swap in different task set managers if necessary + private[scheduler] def createTaskSetManager( + taskSet: TaskSet, + maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures) + } + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => @@ -361,22 +367,22 @@ private[spark] class TaskSchedulerImpl( dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) } - def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) { + def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { taskSetManager.handleTaskGettingResult(tid) } def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]): Unit = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: TaskEndReason) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: TaskEndReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to @@ -387,7 +393,7 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { @@ -400,8 +406,7 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } @@ -416,7 +421,7 @@ private[spark] class TaskSchedulerImpl( starvationTimer.cancel() } - override def defaultParallelism() = backend.defaultParallelism() + override def defaultParallelism(): Int = backend.defaultParallelism() // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { @@ -429,7 +434,7 @@ private[spark] class TaskSchedulerImpl( } } - def executorLost(executorId: String, reason: ExecutorLossReason) { + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c94c6bbcb37..7dc325283d96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.Arrays +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -29,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -51,7 +53,7 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, - clock: Clock = SystemClock) + clock: Clock = new SystemClock()) extends Schedulable with Logging { val conf = sched.sc.conf @@ -97,7 +99,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] - override def runningTasks = runningTasksSet.size + + override def runningTasks: Int = runningTasksSet.size // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the @@ -166,11 +169,11 @@ private[spark] class TaskSetManager( // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level + var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level - override def schedulableQueue = null + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null - override def schedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.NONE var emittedTaskSizeWarning = false @@ -281,7 +284,7 @@ private[spark] class TaskSetManager( val failed = failedExecutors.get(taskId).get return failed.contains(execId) && - clock.getTime() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT + clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT } false @@ -292,7 +295,8 @@ private[spark] class TaskSetManager( * an attempt running on this host, in case the host is slow. In addition, the task should meet * the given locality constraint. */ - private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + // Labeled as protected to allow tests to override providing speculative tasks if necessary + protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) : Option[(Int, TaskLocality.Value)] = { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set @@ -427,7 +431,7 @@ private[spark] class TaskSetManager( : Option[TaskDescription] = { if (!isZombie) { - val curTime = clock.getTime() + val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -458,7 +462,7 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTime() + val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { @@ -506,13 +510,64 @@ private[spark] class TaskSetManager( * Get the level we can launch tasks according to delay scheduling, based on current wait time. */ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 + // Remove the scheduled or finished tasks lazily + def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = { + var indexOffset = pendingTaskIds.size + while (indexOffset > 0) { + indexOffset -= 1 + val index = pendingTaskIds(indexOffset) + if (copiesRunning(index) == 0 && !successful(index)) { + return true + } else { + pendingTaskIds.remove(indexOffset) + } + } + false + } + // Walk through the list of tasks that can be scheduled at each location and returns true + // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have + // already been scheduled. + def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = { + val emptyKeys = new ArrayBuffer[String] + val hasTasks = pendingTasks.exists { + case (id: String, tasks: ArrayBuffer[Int]) => + if (tasksNeedToBeScheduledFrom(tasks)) { + true + } else { + emptyKeys += id + false + } + } + // The key could be executorId, host or rackId + emptyKeys.foreach(id => pendingTasks.remove(id)) + hasTasks + } + + while (currentLocalityIndex < myLocalityLevels.length - 1) { + val moreTasks = myLocalityLevels(currentLocalityIndex) match { + case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) + case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost) + case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty + case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack) + } + if (!moreTasks) { + // This is a performance optimization: if there are no more tasks that can + // be scheduled at a particular locality level, there is no point in waiting + // for the locality wait timeout (SPARK-4939). + lastLaunchTime = curTime + logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + + s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") + currentLocalityIndex += 1 + } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { + // Jump to the next locality level, and reset lastLaunchTime so that the next locality + // wait timer doesn't immediately expire + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " + + s"${localityWaits(currentLocalityIndex)}ms") + } else { + return myLocalityLevels(currentLocalityIndex) + } } myLocalityLevels(currentLocalityIndex) } @@ -533,7 +588,7 @@ private[spark] class TaskSetManager( /** * Marks the task as getting result and notifies the DAG Scheduler */ - def handleTaskGettingResult(tid: Long) = { + def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) @@ -542,7 +597,7 @@ private[spark] class TaskSetManager( /** * Check whether has enough quota to fetch the result with `size` bytes */ - def canFetchMoreResults(size: Long): Boolean = synchronized { + def canFetchMoreResults(size: Long): Boolean = sched.synchronized { totalResultSize += size calculatedTasks += 1 if (maxResultSize > 0 && totalResultSize > maxResultSize) { @@ -560,7 +615,7 @@ private[spark] class TaskSetManager( /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index info.markSuccessful() @@ -622,7 +677,7 @@ private[spark] class TaskSetManager( return } val key = ef.description - val now = clock.getTime() + val now = clock.getTimeMillis() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -654,10 +709,13 @@ private[spark] class TaskSetManager( } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). - put(info.executorId, clock.getTime()) + put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED) { + if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { + // If a task failed because its attempt to commit was denied, do not count this failure + // towards failing the stage. This is intended to prevent spurious stage failures in cases + // where many speculative tasks are launched and denied to commit. assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -671,7 +729,7 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - def abort(message: String) { + def abort(message: String): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) isZombie = true @@ -766,7 +824,7 @@ private[spark] class TaskSetManager( val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { - val time = clock.getTime() + val time = clock.getTimeMillis() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) @@ -790,15 +848,18 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - conf.get("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - conf.get("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - conf.get("spark.locality.wait.rack", defaultWait).toLong - case _ => 0L + val defaultWait = conf.get("spark.locality.wait", "3s") + val localityWaitKey = level match { + case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" + case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" + case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" + case _ => null + } + + if (localityWaitKey != null) { + conf.getTimeAsMs(localityWaitKey, defaultWait) + } else { + 0L } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5..70364cea62a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -39,7 +40,12 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + case class RegisterExecutor( + executorId: String, + executorRef: RpcEndpointRef, + hostPort: String, + cores: Int, + logUrls: Map[String, String]) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } @@ -66,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -73,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5786d367464f..9656fb76858e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. @@ -57,8 +52,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) + val maxRegisteredWaitingTimeMs = + conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") val createTime = System.currentTimeMillis() private val executorDataMap = new HashMap[String, ExecutorData] @@ -71,47 +66,27 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } + private val addressToExecutorId = new HashMap[RpcAddress, String] - def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post(SparkListenerExecutorAdded(executorId, data)) - makeOffers() + override def onStart() { + // Periodically revive offers to allow delay scheduling to work + val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") + + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveIntervalMs, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -132,33 +107,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -168,6 +168,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -198,7 +203,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -210,19 +215,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { + addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingToRemove -= executorId } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) - listenerBus.post(SparkListenerExecutorRemoved(executorId)) + listenerBus.post( + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -233,16 +244,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -253,22 +263,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -278,11 +287,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -294,9 +302,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } - if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { + if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTimeMs) { logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTimeMs(ms)") return true } false @@ -309,9 +317,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste /** * Request an additional number of executors from the cluster manager. - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + if (numAdditionalExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of additional executor(s) " + + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") + } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") numPendingExecutors += numAdditionalExecutors @@ -320,6 +333,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste doRequestTotalExecutors(newTotal) } + /** + * Express a preference to the cluster manager for a given total number of executors. This can + * result in canceling pending requests or filing additional requests. + * @return whether the request is acknowledged. + */ + final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + if (numExecutors < 0) { + throw new IllegalArgumentException( + "Attempted to request a negative number of executor(s) " + + s"$numExecutors from the cluster manager. Please specify a positive number!") + } + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) + doRequestTotalExecutors(numExecutors) + } + /** * Request executors from the cluster manager by specifying the total number desired, * including existing pending and running executors. @@ -330,7 +359,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * Return whether the request is acknowledged. + * @return whether the request is acknowledged. */ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false @@ -348,6 +377,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logWarning(s"Executor to kill $id does not exist!") } } + // Killing executors means effectively that we want less executors than before, so also update + // the target number of executors to avoid having the backend allocate new ones. + val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size + - filteredExecutorIds.size) + doRequestTotalExecutors(newTotal) + executorsPendingToRemove ++= filteredExecutorIds doKillExecutors(filteredExecutorIds) } @@ -361,5 +396,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index eb52ddfb1eab..26e72c0bff38 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,21 +17,22 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, - override val totalCores: Int -) extends ExecutorInfo(executorHost, totalCores) + override val totalCores: Int, + override val logUrlMap: Map[String, String] +) extends ExecutorInfo(executorHost, totalCores, logUrlMap) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index b4738e64c939..7f218566146a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -25,8 +25,8 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi class ExecutorInfo( val executorHost: String, - val totalCores: Int -) { + val totalCores: Int, + val logUrlMap: Map[String, String]) { def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] @@ -34,12 +34,13 @@ class ExecutorInfo( case that: ExecutorInfo => (that canEqual this) && executorHost == that.executorHost && - totalCores == that.totalCores + totalCores == that.totalCores && + logUrlMap == that.logUrlMap case _ => false } override def hashCode(): Int = { - val state = Seq(executorHost, totalCores) + val state = Seq(executorHost, totalCores, logUrlMap) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index ee10aa061f4e..0324c9dab910 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl @@ -27,7 +28,7 @@ private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -38,11 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7eb87a564d6f..ccf1dc5af612 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,6 +17,9 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.Semaphore + +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} @@ -27,32 +30,35 @@ private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { - var client: AppClient = null - var stopping = false - var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - @volatile var appId: String = _ + private var client: AppClient = null + private var stopping = false + + @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ + @volatile private var appId: String = _ - val registrationLock = new Object() - var registrationDone = false + private val registrationBarrier = new Semaphore(0) - val maxCores = conf.getOption("spark.cores.max").map(_.toInt) - val totalExpectedCores = maxCores.getOrElse(0) + private val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + private val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() // The endpoint for executors to talk to us - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", - "{{WORKER_URL}}") + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val args = Seq( + "--driver-url", driverUrl, + "--executor-id", "{{EXECUTOR_ID}}", + "--hostname", "{{HOSTNAME}}", + "--cores", "{{CORES}}", + "--app-id", "{{APP_ID}}", + "--worker-url", "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath") @@ -76,12 +82,11 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir) - + val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, + command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() - waitForRegistration() } @@ -89,8 +94,10 @@ private[spark] class SparkDeploySchedulerBackend( stopping = true super.stop() client.stop() - if (shutdownCallback != null) { - shutdownCallback(this) + + val callback = shutdownCallback + if (callback != null) { + callback(this) } } @@ -111,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. + sc.stop() + } } } @@ -143,18 +153,11 @@ private[spark] class SparkDeploySchedulerBackend( } private def waitForRegistration() = { - registrationLock.synchronized { - while (!registrationDone) { - registrationLock.wait() - } - } + registrationBarrier.acquire() } private def notifyContext() = { - registrationLock.synchronized { - registrationDone = true - registrationLock.notifyAll() - } + registrationBarrier.release() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25..d987c7d56357 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,14 +19,12 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, RpcUtils} import scala.util.control.NonFatal @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,28 +43,24 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private val askAmThreadPool = + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + override def receive: PartialFunction[Any, Unit] = { + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) - override def receive = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 5289661eb896..82f652dae037 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -28,10 +28,10 @@ import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, AkkaUtils} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -143,28 +143,36 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + val driverUrl = AkkaUtils.address( + AkkaUtils.protocol(sc.env.actorSystem), SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( - prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" + .format(prefixEnv, runScript) + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - ("cd %s*; %s " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") - .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores, appId)) + s"cd $basename*; $prefixEnv " + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + s" --driver-url $driverUrl" + + s" --executor-id ${offer.getSlaveId.getValue}" + + s" --hostname ${offer.getHostname}" + + s" --cores $numCores" + + s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() @@ -199,7 +207,7 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() + val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val slaveId = offer.getSlaveId.toString @@ -254,20 +262,12 @@ private[spark] class CoarseMesosSchedulerBackend( .build() } - /** Check whether a Mesos task state represents a finished task */ - private def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState logInfo("Mesos task " + taskId + " is now " + state) synchronized { - if (isFinished(state)) { + if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) slaveIdsWithExecutors -= slaveId taskIdToSlaveId -= taskId @@ -277,7 +277,7 @@ private[spark] class CoarseMesosSchedulerBackend( coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { + if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala index 5101ec8352e7..8df4f3b554c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -21,15 +21,11 @@ import org.apache.spark.SparkContext private[spark] object MemoryUtils { // These defaults copied from YARN - val OVERHEAD_FRACTION = 1.07 + val OVERHEAD_FRACTION = 0.10 val OVERHEAD_MINIMUM = 384 - def calculateTotalMemory(sc: SparkContext) = { - math.max( - sc.conf.getOption("spark.mesos.executor.memoryOverhead") - .getOrElse(OVERHEAD_MINIMUM.toString) - .toInt + sc.executorMemory, - OVERHEAD_FRACTION * sc.executorMemory - ) + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 79c9051e8869..d9d62b0e287e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -67,6 +67,8 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. val listenerBus = sc.listenerBus + + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) @volatile var appId: String = _ @@ -139,7 +141,7 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(scheduler.CPUS_PER_TASK).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") @@ -220,10 +222,9 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? (mem >= MemoryUtils.calculateTotalMemory(sc) && // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || + cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) } @@ -232,10 +233,9 @@ private[spark] class MesosSchedulerBackend( val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { - // If the executor doesn't exist yet, subtract CPU for executor - // TODO(pwendell): Should below just subtract "1"? - getResource(o.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK + // If the Mesos executor has not been started on this slave yet, set aside a few + // cores for the Mesos executor by offering fewer cores to the Spark executor + (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt } new WorkerOffer( o.getSlaveId.getValue, @@ -269,8 +269,9 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => - listenerBus.post(SparkListenerExecutorAdded(slaveId, - new ExecutorInfo(o.host, o.cores))) + listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, + // TODO: Add support for log urls for Mesos + new ExecutorInfo(o.host, o.cores, Map.empty))) ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -312,24 +313,17 @@ private[spark] class MesosSchedulerBackend( .build() } - /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + if (TaskState.isFailed(TaskState.fromMesos(status.getState)) + && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - removeExecutor(taskIdToSlaveId(tid)) + removeExecutor(taskIdToSlaveId(tid), "Lost executor") } - if (isFinished(status.getState)) { + if (TaskState.isFinished(state)) { taskIdToSlaveId.remove(tid) } } @@ -359,9 +353,9 @@ private[spark] class MesosSchedulerBackend( /** * Remove executor associated with slaveId in a thread safe manner. */ - private def removeExecutor(slaveId: String) = { + private def removeExecutor(slaveId: String, reason: String) = { synchronized { - listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) slaveIdsWithExecutors -= slaveId } } @@ -369,7 +363,7 @@ private[spark] class MesosSchedulerBackend( private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - removeExecutor(slaveId.getValue) + removeExecutor(slaveId.getValue, reason.toString) scheduler.executorLost(slaveId.getValue, reason) } } @@ -393,7 +387,7 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8) override def applicationId(): String = Option(appId).getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 05b6fa54564b..ac5b52451781 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,14 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit -import akka.actor.{Actor, ActorRef, Props} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.{ThreadUtils, Utils} private case class ReviveOffers() @@ -36,15 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { + + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") private var freeCores = totalCores @@ -54,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -67,18 +71,35 @@ private[spark] class LocalActor( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } + def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - for (task <- scheduler.resourceOffers(offers).flatten) { + val tasks = scheduler.resourceOffers(offers).flatten + for (task <- tasks) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } + if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { + // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) + } + } + + override def onStop(): Unit = { + reviveThread.shutdownNow() } } @@ -87,35 +108,37 @@ private[spark] class LocalActor( * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.sendWithReply(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } - override def defaultParallelism() = + override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index fa8a337ad63a..dfbde7c8a1b0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -27,7 +27,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils -private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) +private[spark] class JavaSerializationStream( + out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) extends SerializationStream { private val objOut = new ObjectOutputStream(out) private var counter = 0 @@ -39,7 +40,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { - objOut.writeObject(t) + try { + objOut.writeObject(t) + } catch { + case e: NotSerializableException if extraDebugInfo => + throw SerializationDebugger.improveException(t, e) + } counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() @@ -53,9 +59,10 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { + extends DeserializationStream { + private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } @@ -64,7 +71,8 @@ extends DeserializationStream { } -private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) +private[spark] class JavaSerializerInstance( + counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { @@ -88,7 +96,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade } override def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s, counterReset) + new JavaSerializationStream(s, counterReset, extraDebugInfo) } override def deserializeStream(s: InputStream): DeserializationStream = { @@ -111,17 +119,20 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade @DeveloperApi class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) + private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) - new JavaSerializerInstance(counterReset, classLoader) + new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(counterReset) + out.writeBoolean(extraDebugInfo) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { counterReset = in.readInt() + extraDebugInfo = in.readBoolean() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d56e23ce4478..579fb6624e69 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -20,22 +20,23 @@ package org.apache.spark.serializer import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} +import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer -import scala.reflect.ClassTag - /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * @@ -48,26 +49,28 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) + if (bufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + + s"2048 mb, got: + $bufferSizeMb mb.") + } + private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + + val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + if (maxBufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) - .map { className => - try { - Class.forName(className) - } catch { - case e: Exception => - throw new SparkException("Failed to load class to register with Kryo", e) - } - } - def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator @@ -97,7 +100,8 @@ class KryoSerializer(conf: SparkConf) // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. - classesToRegister.foreach { clazz => kryo.register(clazz) } + classesToRegister + .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) @@ -164,7 +168,13 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() - kryo.writeClassAndObject(output, t) + try { + kryo.writeClassAndObject(output, t) + } catch { + case e: KryoException if e.getMessage.startsWith("Buffer overflow") => + throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + + "increase spark.kryoserializer.buffer.max.mb value.") + } ByteBuffer.wrap(output.toBytes) } @@ -209,9 +219,17 @@ private[serializer] object KryoSerializer { classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], + classOf[RoaringBitmap], + classOf[RoaringArray], + classOf[RoaringArray.Element], + classOf[Array[RoaringArray.Element]], + classOf[ArrayContainer], + classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], + classOf[Array[Short]], + classOf[Array[Long]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala new file mode 100644 index 000000000000..cecb99257965 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.lang.reflect.{Field, Method} +import java.security.AccessController + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.Logging + +private[serializer] object SerializationDebugger extends Logging { + + /** + * Improve the given NotSerializableException with the serialization path leading from the given + * object to the problematic object. This is turned off automatically if + * `sun.io.serialization.extendedDebugInfo` flag is turned on for the JVM. + */ + def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { + if (enableDebugging && reflect != null) { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } else { + e + } + } + + /** + * Find the path leading to a not serializable object. This method is modeled after OpenJDK's + * serialization mechanism, and handles the following cases: + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace + * + * It does not yet handle writeObject override, but that shouldn't be too hard to do either. + */ + def find(obj: Any): List[String] = { + new SerializationDebugger().visit(obj, List.empty) + } + + private[serializer] var enableDebugging: Boolean = { + !AccessController.doPrivileged(new sun.security.action.GetBooleanAction( + "sun.io.serialization.extendedDebugInfo")).booleanValue() + } + + private class SerializationDebugger { + + /** A set to track the list of objects we have visited, to avoid cycles in the graph. */ + private val visited = new mutable.HashSet[Any] + + /** + * Visit the object and its fields and stop when we find an object that is not serializable. + * Return the path as a list. If everything can be serialized, return an empty list. + */ + def visit(o: Any, stack: List[String]): List[String] = { + if (o == null) { + List.empty + } else if (visited.contains(o)) { + List.empty + } else { + visited += o + o match { + // Primitive value, string, and primitive arrays are always serializable + case _ if o.getClass.isPrimitive => List.empty + case _: String => List.empty + case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty + + // Traverse non primitive array. + case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive => + val elem = s"array (class ${a.getClass.getName}, size ${a.length})" + visitArray(o.asInstanceOf[Array[_]], elem :: stack) + + case e: java.io.Externalizable => + val elem = s"externalizable object (class ${e.getClass.getName}, $e)" + visitExternalizable(e, elem :: stack) + + case s: Object with java.io.Serializable => + val elem = s"object (class ${s.getClass.getName}, $s)" + visitSerializable(s, elem :: stack) + + case _ => + // Found an object that is not serializable! + s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack + } + } + } + + private def visitArray(o: Array[_], stack: List[String]): List[String] = { + var i = 0 + while (i < o.length) { + val childStack = visit(o(i), s"element of array (index: $i)" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = + { + val fieldList = new ListObjectOutput + o.writeExternal(fieldList) + val childObjects = fieldList.outputArray + var i = 0 + while (i < childObjects.length) { + val childStack = visit(childObjects(i), "writeExternal data" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitSerializable(o: Object, stack: List[String]): List[String] = { + // An object contains multiple slots in serialization. + // Get the slots and visit fields in all of them. + val (finalObj, desc) = findObjectAndDescriptor(o) + val slotDescs = desc.getSlotDescs + var i = 0 + while (i < slotDescs.length) { + val slotDesc = slotDescs(i) + if (slotDesc.hasWriteObjectMethod) { + // TODO: Handle classes that specify writeObject method. + } else { + val fields: Array[ObjectStreamField] = slotDesc.getFields + val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) + val numPrims = fields.length - objFieldValues.length + desc.getObjFieldValues(finalObj, objFieldValues) + + var j = 0 + while (j < objFieldValues.length) { + val fieldDesc = fields(numPrims + j) + val elem = s"field (class: ${slotDesc.getName}" + + s", name: ${fieldDesc.getName}" + + s", type: ${fieldDesc.getType})" + val childStack = visit(objFieldValues(j), elem :: stack) + if (childStack.nonEmpty) { + return childStack + } + j += 1 + } + + } + i += 1 + } + return List.empty + } + } + + /** + * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles + * writeReplace in Serializable. It starts with the object itself, and keeps calling the + * writeReplace method until there is no more + */ + @tailrec + private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { + val cl = o.getClass + val desc = ObjectStreamClass.lookupAny(cl) + if (!desc.hasWriteReplaceMethod) { + (o, desc) + } else { + // write place + findObjectAndDescriptor(desc.invokeWriteReplace(o)) + } + } + + /** + * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal + * call, and returns them through `outputArray`. + */ + private class ListObjectOutput extends ObjectOutput { + private val output = new mutable.ArrayBuffer[Any] + def outputArray: Array[Any] = output.toArray + override def writeObject(o: Any): Unit = output += o + override def flush(): Unit = {} + override def write(i: Int): Unit = {} + override def write(bytes: Array[Byte]): Unit = {} + override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {} + override def close(): Unit = {} + override def writeFloat(v: Float): Unit = {} + override def writeChars(s: String): Unit = {} + override def writeDouble(v: Double): Unit = {} + override def writeUTF(s: String): Unit = {} + override def writeShort(i: Int): Unit = {} + override def writeInt(i: Int): Unit = {} + override def writeBoolean(b: Boolean): Unit = {} + override def writeBytes(s: String): Unit = {} + override def writeChar(i: Int): Unit = {} + override def writeLong(l: Long): Unit = {} + override def writeByte(i: Int): Unit = {} + } + + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ + implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { + def getSlotDescs: Array[ObjectStreamClass] = { + reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map { + classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass] + } + } + + def hasWriteObjectMethod: Boolean = { + reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean] + } + + def hasWriteReplaceMethod: Boolean = { + reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean] + } + + def invokeWriteReplace(obj: Object): Object = { + reflect.InvokeWriteReplace.invoke(desc, obj) + } + + def getNumObjFields: Int = { + reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int] + } + + def getObjFieldValues(obj: Object, out: Array[Object]): Unit = { + reflect.GetObjFieldValues.invoke(desc, obj, out) + } + } + + /** + * Object to hold all the reflection objects. If we run on a JVM that we cannot understand, + * this field will be null and this the debug helper should be disabled. + */ + private val reflect: ObjectStreamClassReflection = try { + new ObjectStreamClassReflection + } catch { + case e: Exception => + logWarning("Cannot find private methods using reflection", e) + null + } + + private class ObjectStreamClassReflection { + /** ObjectStreamClass.getClassDataLayout */ + val GetClassDataLayout: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteObjectMethod */ + val HasWriteObjectMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteReplaceMethod */ + val HasWriteReplaceMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.invokeWriteReplace */ + val InvokeWriteReplace: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getNumObjFields */ + val GetNumObjFields: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getObjFieldValues */ + val GetObjFieldValues: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod( + "getObjFieldValues", classOf[Object], classOf[Array[Object]]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass$ClassDataSlot.desc field */ + val DescField: Field = { + val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + f.setAccessible(true) + f + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 7de2f9cbb286..538e150ead05 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup { // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { + extends ShuffleBlockResolver with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -106,17 +106,19 @@ class FileShuffleBlockManager(conf: SparkConf) * when the writers are closed successfully */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics) = { + writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null + val openStartTime = System.nanoTime + val serializerInstance = serializer.newInstance() val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { @@ -132,9 +134,13 @@ class FileShuffleBlockManager(conf: SparkConf) logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) + blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, + writeMetrics) } } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { @@ -171,11 +177,6 @@ class FileShuffleBlockManager(conf: SparkConf) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. @@ -268,7 +269,7 @@ object FileShuffleBlockManager { new PrimitiveVector[Long]() } - def apply(bucketId: Int) = files(bucketId) + def apply(bucketId: Int): File = files(bucketId) def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { assert(offsets.length == lengths.length) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index b292587d3702..a1741e2875c1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -26,6 +26,9 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +import IndexShuffleBlockManager.NOOP_REDUCE_ID /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -39,25 +42,18 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { private lazy val blockManager = SparkEnv.get.blockManager private val transportConf = SparkTransportConf.fromSparkConf(conf) - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** @@ -80,27 +76,22 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { * end of the output file. This will be used by getBlockLocation to figure out where each block * begins and ends. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { + Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L out.writeLong(offset) - for (length <- lengths) { offset += length out.writeLong(offset) } - } finally { + } { out.close() } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index @@ -121,5 +112,13 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { } } - override def stop() = {} + override def stop(): Unit = {} +} + +private[spark] object IndexShuffleBlockManager { + // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. + // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. + val NOOP_REDUCE_ID = 0 } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index b521f0c7fc77..4342b0d598b1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] -trait ShuffleBlockManager { +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { type ShuffleId = Int /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e124925..978366d1a1d1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9b..f6e6fe5defe0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. */ private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ + /** Write a sequence of records to this task's output */ def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index e3e7434df45b..7a2c5ae32d98 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -86,6 +86,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { context.taskMetrics.updateShuffleReadMetrics() }) - new InterruptibleIterator[T](context, completionIter) + new InterruptibleIterator[T](context, completionIter) { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + override def next(): T = { + readMetrics.incRecordsRead(1) + delegate.next() + } + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b3440..2a7df8dd5bd8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { + override def shuffleBlockResolver: FileShuffleBlockManager = { fileShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808..049703619215 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { + override def shuffleBlockResolver: IndexShuffleBlockManager = { indexShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 27496c5a289c..a066435df6fb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -58,13 +58,15 @@ private[spark] class SortShuffleWriter[K, V, C]( // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) + sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) sorter.insertAll(records) } + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) @@ -88,9 +90,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } finally { // Clean up our sorter, which may have its own intermediate files if (sorter != null) { + val startTime = System.nanoTime() sorter.stop() + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - startTime)) sorter = null } } } } + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1f012941c85a..c186fd360fef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -35,13 +35,13 @@ sealed abstract class BlockId { def name: String // convenience methods - def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None - def isRDD = isInstanceOf[RDDBlockId] - def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD: Boolean = isInstanceOf[RDDBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] - override def toString = name - override def hashCode = name.hashCode + override def toString: String = name + override def hashCode: Int = name.hashCode override def equals(other: Any): Boolean = other match { case o: BlockId => getClass == o.getClass && name.equals(o.name) case _ => false @@ -50,54 +50,54 @@ sealed abstract class BlockId { @DeveloperApi case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - def name = "rdd_" + rddId + "_" + splitIndex + override def name: String = "rdd_" + rddId + "_" + splitIndex } // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { - def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) + override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { - def name = "taskresult_" + taskId + override def name: String = "taskresult_" + taskId } @DeveloperApi case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { - def name = "input-" + streamId + "-" + uniqueId + override def name: String = "input-" + streamId + "-" + uniqueId } /** Id associated with temporary local data managed as blocks. Not serializable. */ private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { - def name = "temp_local_" + id + override def name: String = "temp_local_" + id } /** Id associated with temporary shuffle data managed as blocks. Not serializable. */ private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { - def name = "temp_shuffle_" + id + override def name: String = "temp_shuffle_" + id } // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { - def name = "test_" + id + override def name: String = "test_" + id } @DeveloperApi @@ -112,7 +112,7 @@ object BlockId { val TEST = "test_(.*)".r /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String) = id match { + def apply(id: String): BlockId = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8bc5a1cd18b6..55718e584c19 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,18 +26,18 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor._ +import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo -import org.apache.spark.serializer.Serializer +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ @@ -50,11 +50,8 @@ private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], - readMethod: DataReadMethod.Value, - bytes: Long) { - val inputMetrics = new InputMetrics(readMethod) - inputMetrics.addBytesRead(bytes) -} + val readMethod: DataReadMethod.Value, + val bytes: Long) /** * Manager running on every node (driver and executors) which provides interfaces for putting and @@ -64,7 +61,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +133,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +164,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +173,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +183,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +199,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +262,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -301,7 +298,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -439,14 +436,10 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockManager = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. + Option(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } @@ -535,9 +528,14 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. */ - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) + memoryStore.putBytes(blockId, bytes.limit, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we cannot + // put it into MemoryStore, copyForMemory should not be created. That's why this + // action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + }) bytes.rewind() } if (!asBlockResult) { @@ -645,13 +643,13 @@ private[spark] class BlockManager( def getDiskWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, - writeMetrics) + new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, + syncWrites, writeMetrics) } /** @@ -991,15 +989,23 @@ private[spark] class BlockManager( putIterator(blockId, Iterator(value), level, tellMaster) } + def dropFromMemory( + blockId: BlockId, + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + dropFromMemory(blockId, () => data) + } + /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * + * If `data` is not put on disk, it won't be created. + * * Return the block status if the given block has been updated, else None. */ def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1023,7 +1029,7 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") - data match { + data() match { case Left(elements) => diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => @@ -1074,7 +1080,7 @@ private[spark] class BlockManager( * Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { - logInfo(s"Removing broadcast $broadcastId") + logDebug(s"Removing broadcast $broadcastId") val blocksToRemove = blockInfo.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } @@ -1086,7 +1092,7 @@ private[spark] class BlockManager( * Remove a block from both memory and disk. */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { - logInfo(s"Removing block $blockId") + logDebug(s"Removing block $blockId") val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { @@ -1206,7 +1212,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() @@ -1245,10 +1251,10 @@ private[spark] object BlockManager extends Logging { } } - def blockIdsToBlockManagers( + def blockIdsToHosts( blockIds: Array[BlockId], env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[BlockManagerId]] = { + blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { // blockManagerMaster != null is used in tests assert(env != null || blockManagerMaster != null) @@ -1258,24 +1264,10 @@ private[spark] object BlockManager extends Logging { blockManagerMaster.getLocations(blockIds) } - val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]] + val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i) + blockManagers(blockIds(i)) = blockLocations(i).map(_.host) } blockManagers.toMap } - - def blockIdsToExecutorIds( - blockIds: Array[BlockId], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) - } - - def blockIdsToHosts( - blockIds: Array[BlockId], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) - } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b177a59c721d..69ac37511e73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) @@ -77,11 +80,11 @@ class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port)" override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case id: BlockManagerId => executorId == id.executorId && port == id.port && host == id.host case _ => @@ -100,10 +103,10 @@ private[spark] object BlockManagerId { * @param port Port of the block manager. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int) = + def apply(execId: String, host: String, port: Int): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port)) - def apply(in: ObjectInput) = { + def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) getCachedBlockManagerId(obj) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index b63c7f191155..c798843bd5d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,37 +55,37 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) - logInfo("Updated info of block " + blockId) + logDebug(s"Updated info of block $blockId") res } /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. */ - def contains(blockId: BlockId) = { + def contains(blockId: BlockId): Boolean = { !getLocations(blockId).isEmpty } /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala similarity index 78% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 64133464d8da..4682167912ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} /** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,87 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) - - val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000) + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) - - var timeoutCheckingTask: Cancellable = null - - override def preStart() { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - super.preStart() - } - - override def receiveWithLogging = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + context.reply(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + context.reply(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + context.reply(getPeers(blockManagerId)) - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + context.reply(memoryStatus) case GetStorageStatus => - sender ! storageStatus + context.reply(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + context.reply(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + context.reply(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + context.reply(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + context.reply(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + context.reply(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + context.reply(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + context.reply(true) case StopBlockManagerMaster => - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() + context.reply(true) + stop() case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) + context.reply(heartbeatReceived(blockManagerId)) - case other => - logWarning("Got unknown message: " + other) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -148,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher + // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) }.toSeq ) } @@ -174,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } @@ -207,21 +185,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus logInfo(s"Removing block manager $blockManagerId") } - private def expireDeadHosts() { - logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new mutable.HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " - + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") - toRemove += info.blockManagerId - } - } - toRemove.foreach(removeBlockManager) - } - private def removeExecutor(execId: String) { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) @@ -251,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) } } } @@ -281,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* - * Rather than blocking on the block status query, master actor should simply return + * Rather than blocking on the block status query, master endpoint should simply return * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. + * that is also waiting for this master endpoint's response to a previous message. */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -310,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -325,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -342,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -413,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) } } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } } @DeveloperApi @@ -446,7 +409,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveEndpoint: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs @@ -455,7 +418,7 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3f32099d08cc..f89d8d7493f7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster @@ -109,6 +108,4 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - - case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 8462871e798a..543df4e1350d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -17,41 +17,43 @@ package org.apache.spark.storage -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} +import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive /** - * An actor to take commands from the master to execute options. For example, + * An RpcEndpoint to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor( +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { - import context.dispatcher + private val asyncThreadPool = + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { + doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { + doAsync[Int]("removing RDD " + rddId, context) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, context) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -59,30 +61,34 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { + doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + context.reply(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + context.reply(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] + context.sendFailure(t) } } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 3198d766fca3..14833791f7a4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -21,15 +21,17 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. This interface allows * appending data to an existing block, and can guarantee atomicity in the case of faults * as it allows the caller to revert partial writes. * - * This interface does not support concurrent writes. + * This interface does not support concurrent writes. Also, once the writer has + * been opened, it cannot be reopened again. */ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { @@ -69,7 +71,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, @@ -81,11 +83,13 @@ private[spark] class DiskBlockObjectWriter( { /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]) = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) - override def close() = out.close() - override def flush() = out.flush() + override def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + callWithTiming(out.write(b, off, len)) + } + override def close(): Unit = out.close() + override def flush(): Unit = out.flush() } /** The file channel, used for repositioning / truncating the file. */ @@ -95,6 +99,7 @@ private[spark] class DiskBlockObjectWriter( private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private var initialized = false + private var hasBeenClosed = false /** * Cursors used to represent positions in the file. @@ -115,29 +120,38 @@ private[spark] class DiskBlockObjectWriter( private var finalPosition: Long = -1 private var reportedPosition = initialPosition - /** Calling channel.position() to update the write metrics can be a little bit expensive, so we - * only call it every N writes */ - private var writesSinceMetricsUpdate = 0 + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + */ + private var numRecordsWritten = 0 override def open(): BlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) + objOut = serializerInstance.serializeStream(bs) initialized = true this } override def close() { if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - def sync = fos.getFD.sync() - callWithTiming(sync) + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + callWithTiming { + fos.getFD.sync() + } + } + } { + objOut.close() } - objOut.close() channel = null bs = null @@ -145,6 +159,7 @@ private[spark] class DiskBlockObjectWriter( ts = null objOut = null initialized = false + hasBeenClosed = true } } @@ -168,6 +183,7 @@ private[spark] class DiskBlockObjectWriter( override def revertPartialWritesAndClose() { try { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) if (initialized) { objOut.flush() @@ -193,12 +209,11 @@ private[spark] class DiskBlockObjectWriter( } objOut.writeObject(value) + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) - if (writesSinceMetricsUpdate == 32) { - writesSinceMetricsUpdate = 0 + if (numRecordsWritten % 32 == 0) { updateBytesWritten() - } else { - writesSinceMetricsUpdate += 1 } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index af05eb3ca69c..7ea5e54f9e1f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage +import java.util.UUID import java.io.{IOException, File} -import java.text.SimpleDateFormat -import java.util.{Date, Random, UUID} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode @@ -37,7 +36,6 @@ import org.apache.spark.util.Utils private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf) extends Logging { - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 private[spark] val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) @@ -49,9 +47,11 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - addShutdownHook() + private val shutdownHook = addShutdownHook() /** Looks up a file by hashing it into one of our local subdirectories. */ // This method should be kept in sync with @@ -63,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -93,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } @@ -123,48 +125,34 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def createLocalDirs(conf: SparkConf): Array[File] = { - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, s"spark-local-$localDirId") - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning(s"Attempt $tries to create local dir $localDir failed", e) - } - } - if (!foundLocalDir) { - logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir." + - " Ignoring this directory.") - None - } else { + try { + val localDir = Utils.createDirectory(rootDir, "blockmgr") logInfo(s"Created local directory at $localDir") Some(localDir) + } catch { + case e: IOException => + logError(s"Failed to create local dir in $rootDir. Ignoring this directory.", e) + None } } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - DiskBlockManager.this.stop() - } - }) + private def addShutdownHook(): AnyRef = { + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + DiskBlockManager.this.doStop() + } } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { + // Remove the shutdown hook. It causes memory leaks if we leave it around. + Utils.removeShutdownHook(shutdownHook) + doStop() + } + + private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { localDirs.foreach { localDir => diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff16879..4b232ae7d318 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -46,10 +46,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +78,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. outputStream.close() } @@ -106,8 +109,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +125,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 132502b75f8c..95e2d688d9b1 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,5 +24,7 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 71305a46bf57..ed609772e697 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -46,6 +46,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() + // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. + // Pending unroll memory refers to the intermediate memory occupied by a thread + // after the unroll but before the actual putting of the block in the cache. + // This chunk of memory is expected to be released *as soon as* we finish + // caching the corresponding block as opposed to until after the task finishes. + // This is only used if a block is successfully unrolled in its entirety in + // memory (SPARK-4777). + private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() /** * The amount of space ensured for unrolling values in memory, shared across all cores. @@ -90,6 +98,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + */ + def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + // Work on a duplicate - since the original input might be used elsewhere. + lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val data = + if (putAttempt.success) { + assert(bytes.limit == size) + Right(bytes.duplicate()) + } else { + null + } + PutResult(size, data, putAttempt.droppedBlocks) + } + override def putArray( blockId: BlockId, values: Array[Any], @@ -184,7 +212,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val entry = entries.remove(blockId) if (entry != null) { currentMemory -= entry.size - logInfo(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") + logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") true } else { false @@ -283,12 +311,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } finally { - // If we return an array, the values returned do not depend on the underlying vector and - // we can immediately free up space for other threads. Otherwise, if we return an iterator, - // we release the memory claimed by this thread later on when the task finishes. + // If we return an array, the values returned will later be cached in `tryToPut`. + // In this case, we should release the memory after we cache the block there. + // Otherwise, if we return an iterator, we release the memory reserved here + // later when the task finishes. if (keepUnrolling) { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) + accountingLock.synchronized { + val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved + releaseUnrollMemoryForThisThread(amountToRelease) + reservePendingUnrollMemoryForThisThread(amountToRelease) + } } } } @@ -300,11 +332,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId.asRDDId.map(_.rddId) } + private def tryToPut( + blockId: BlockId, + value: Any, + size: Long, + deserialized: Boolean): ResultWithDroppedBlocks = { + tryToPut(blockId, () => value, size, deserialized) + } + /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size * must also be passed by the caller. * + * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be + * created to avoid OOM since it may be a big ByteBuffer. + * * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for @@ -314,7 +357,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ private def tryToPut( blockId: BlockId, - value: Any, + value: () => Any, size: Long, deserialized: Boolean): ResultWithDroppedBlocks = { @@ -333,7 +376,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new MemoryEntry(value, size, deserialized) + val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size @@ -345,14 +388,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[Array[Any]]) + lazy val data = if (deserialized) { + Left(value().asInstanceOf[Array[Any]]) } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } + // Release the unroll memory used because we no longer need the underlying Array + releasePendingUnrollMemoryForThisThread() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -381,7 +426,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } // Take into account the amount of memory currently occupied by unrolling blocks - val actualFreeMemory = freeMemory - currentUnrollMemory + // and minus the pending unroll memory for that block on current thread. + val threadId = Thread.currentThread().getId + val actualFreeMemory = freeMemory - currentUnrollMemory + + pendingUnrollMemoryMap.getOrElse(threadId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -468,11 +516,32 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Reserve the unroll memory of current unroll successful block used by this thread + * until actually put the block into memory entry. + */ + def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { + val threadId = Thread.currentThread().getId + accountingLock.synchronized { + pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + } + } + + /** + * Release pending unroll memory of current unroll successful block used by this thread + */ + def releasePendingUnrollMemoryForThisThread(): Unit = { + val threadId = Thread.currentThread().getId + accountingLock.synchronized { + pendingUnrollMemoryMap.remove(threadId) + } + } + /** * Return the amount of memory currently occupied for unrolling blocks across all threads. */ def currentUnrollMemory: Long = accountingLock.synchronized { - unrollMemoryMap.values.sum + unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 120c327a7e58..034525b56f59 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -36,7 +36,7 @@ class RDDInfo( def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 - override def toString = { + override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( @@ -44,7 +44,7 @@ class RDDInfo( bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } - override def compare(that: RDDInfo) = { + override def compare(that: RDDInfo): Int = { this.id - that.id } } @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ab9ee4f0096b..f3379521d55e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -106,6 +106,8 @@ final class ShuffleBlockFetcherIterator( private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -234,6 +236,7 @@ final class ShuffleBlockFetcherIterator( try { val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { @@ -298,7 +301,7 @@ final class ShuffleBlockFetcherIterator( // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { is0 => val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializer.newInstance().deserializeStream(is).asIterator + val iter = serializerInstance.deserializeStream(is).asIterator CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null // so we don't release it again in cleanup. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index e5e1cf5a69a1..134abea86621 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -50,11 +50,11 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = _useDisk - def useMemory = _useMemory - def useOffHeap = _useOffHeap - def deserialized = _deserialized - def replication = _replication + def useDisk: Boolean = _useDisk + def useMemory: Boolean = _useMemory + def useOffHeap: Boolean = _useOffHeap + def deserialized: Boolean = _deserialized + def replication: Int = _replication assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") @@ -80,7 +80,7 @@ class StorageLevel private( false } - def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 @@ -183,7 +183,7 @@ object StorageLevel { useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean, - replication: Int) = { + replication: Int): StorageLevel = { getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) } @@ -197,7 +197,7 @@ object StorageLevel { useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1) = { + replication: Int = 1): StorageLevel = { getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index def49e80a360..7d75929b96f7 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -32,7 +31,7 @@ class StorageStatusListener extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - def storageStatusList = executorIdToStorageStatus.values.toSeq + def storageStatusList: Seq[StorageStatus] = executorIdToStorageStatus.values.toSeq /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { @@ -56,7 +55,7 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { @@ -67,7 +66,7 @@ class StorageStatusListener extends SparkListener { } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index af873034215a..583f1fdf0475 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import tachyon.TachyonURI +import tachyon.client.{TachyonFile, TachyonFS} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(master) else null + val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -113,7 +113,7 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) @@ -135,21 +135,19 @@ private[spark] class TachyonBlockManager( private def addShutdownHook() { tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) - } - } catch { - case e: Exception => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + tachyonDirs.foreach { tachyonDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + Utils.deleteRecursively(tachyonDir, client) } + } catch { + case e: Exception => + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } - client.close() } - }) + client.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala index b86abbda1d3e..65fa81704c36 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala @@ -24,5 +24,7 @@ import tachyon.client.TachyonFile * a length. */ private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 27ba9e18237b..77c0bc8b5360 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -28,7 +28,6 @@ import org.apache.spark._ * of them will be combined together, showed in one line. */ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { - // Carrige return val CR = '\r' // Update period of progress bar, in milliseconds @@ -66,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -82,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => @@ -121,4 +120,10 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { clear() lastFinishTime = System.currentTimeMillis() } + + /** + * Tear down the timer thread. The timer thread is a GC root, and it retains the entire + * SparkContext if it's not terminated. + */ + def stop(): Unit = timer.cancel() } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 88fed833f922..a091ca650c60 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -62,19 +62,28 @@ private[spark] object JettyUtils extends Logging { securityMgr: SecurityManager): HttpServlet = { new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { - response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) - response.setStatus(HttpServletResponse.SC_OK) - val result = servletParams.responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter.println(servletParams.extractFn(result)) - } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, - "User is not authorized to access this page.") + try { + if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter.println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page.") + } + } catch { + case e: IllegalArgumentException => + response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage) } } + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } } } @@ -105,15 +114,32 @@ private[spark] object JettyUtils extends Logging { srcPath: String, destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), - basePath: String = ""): ServletContextHandler = { + basePath: String = "", + httpMethod: String = "GET"): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { - override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "GET" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "POST" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString response.sendRedirect(newUrl) } + // SPARK-5983 ensure TRACE is not supported + protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } } createServletHandler(srcPath, servlet, basePath) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0c24ad2760e0..580ab8b1325f 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,12 +55,12 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + attachHandler(createRedirectHandler( + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) } initialize() - def getAppName = appName + def getAppName: String = appName /** Set the app name for this UI. */ def setAppName(name: String) { diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0..24f323645624 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -25,17 +25,26 @@ private[spark] object ToolTips { of task results.""" val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + """Time spent deserializing the task closure on the executor, including the time to read the + broadcasted task.""" - val INPUT = "Bytes read from Hadoop or from Spark storage." + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." - val OUTPUT = "Bytes written to Hadoop." + val INPUT = "Bytes and records read from Hadoop or from Spark storage." - val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." + val OUTPUT = "Bytes and records written to Hadoop." + + val SHUFFLE_WRITE = + "Bytes and records written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = - """Bytes read from remote executors. Typically less than shuffle write bytes - because this does not include shuffle data read locally.""" + """Total shuffle bytes and records read (includes both data read locally and data read from + remote executors). """ + + val SHUFFLE_READ_REMOTE_SIZE = + """Total shuffle bytes read from remote executors. This is a subset of the shuffle + read bytes; the remaining shuffle data is read locally. """ val GETTING_RESULT_TIME = """Time that the driver spends fetching task results from workers. If this is large, consider diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index b5022fe853c4..f07864141a21 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -149,9 +149,11 @@ private[spark] object UIUtils extends Logging { } } - def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource + def prependBaseUri(basePath: String = "", resource: String = ""): String = { + uiRoot + basePath + resource + } - def commonHeaderNodes = { + def commonHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fc1844600f1c..5fbcd6bb8ad9 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.util.concurrent.Semaphore + import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} @@ -51,7 +53,7 @@ private[spark] object UIWorkloadGenerator { val nJobSet = args(2).toInt val sc = new SparkContext(conf) - def setProperties(s: String) = { + def setProperties(s: String): Unit = { if(schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } @@ -59,7 +61,7 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = new Random().nextFloat() + def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), @@ -88,6 +90,8 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) + val barrier = new Semaphore(-nJobSet * jobs.size + 1) + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { @@ -99,12 +103,17 @@ private[spark] object UIWorkloadGenerator { } catch { case e: Exception => println("Job Failed: " + desc) + } finally { + barrier.release() } } }.start Thread.sleep(INTER_JOB_WAIT_MS) } } + + // Waiting for threads. + barrier.acquire() sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 9be65a4a39a0..f9860d1a5ce7 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -20,14 +20,15 @@ package org.apache.spark.ui import javax.servlet.http.HttpServletRequest import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} -import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * The top level component of the UI hierarchy that contains the server. @@ -45,9 +46,10 @@ private[spark] abstract class WebUI( protected val tabs = ArrayBuffer[WebUITab]() protected val handlers = ArrayBuffer[ServletContextHandler]() + protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostName() - protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) + protected val localHostName = Utils.localHostNameForURI() + protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath @@ -60,14 +62,30 @@ private[spark] abstract class WebUI( tab.pages.foreach(attachPage) tabs += tab } + + def detachTab(tab: WebUITab) { + tab.pages.foreach(detachPage) + tabs -= tab + } + + def detachPage(page: WebUIPage) { + pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) + } /** Attach a page to this UI. */ def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix - attachHandler(createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath)) - attachHandler(createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath)) + val renderHandler = createServletHandler(pagePath, + (request: HttpServletRequest) => page.render(request), securityManager, basePath) + val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", + (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + attachHandler(renderHandler) + attachHandler(renderJsonHandler) + pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) + .append(renderHandler) + pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) + .append(renderJsonHandler) + } /** Attach a handler to this UI. */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c82730f524eb..f0ae95bb8c81 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -43,7 +43,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } id }.getOrElse { - return Text(s"Missing executorId parameter") + throw new IllegalArgumentException(s"Missing executorId parameter") } val time = System.currentTimeMillis() val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 363cb96de799..956608d7c0cb 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -26,7 +26,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils /** Summary information about an executor to display in the UI. */ -private case class ExecutorSummaryInfo( +// Needs to be private[ui] because of a false positive MiMa failure. +private[ui] case class ExecutorSummaryInfo( id: String, hostPort: String, rddBlocks: Int, @@ -40,7 +41,8 @@ private case class ExecutorSummaryInfo( totalInputBytes: Long, totalShuffleRead: Long, totalShuffleWrite: Long, - maxMemory: Long) + maxMemory: Long, + executorLogs: Map[String, String]) private[ui] class ExecutorsPage( parent: ExecutorsTab, @@ -55,6 +57,7 @@ private[ui] class ExecutorsPage( val diskUsed = storageStatusList.map(_.diskUsed).sum val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execInfoSorted = execInfo.sortBy(_.id) + val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty val execTable = @@ -79,10 +82,11 @@ private[ui] class ExecutorsPage( Shuffle Write + {if (logsExist) else Seq.empty} {if (threadDumpEnabled) else Seq.empty} - {execInfoSorted.map(execRow)} + {execInfoSorted.map(execRow(_, logsExist))}
      LogsThread Dump
      @@ -107,7 +111,7 @@ private[ui] class ExecutorsPage( } /** Render an HTML row representing an executor */ - private def execRow(info: ExecutorSummaryInfo): Seq[Node] = { + private def execRow(info: ExecutorSummaryInfo, logsExist: Boolean): Seq[Node] = { val maximumMemory = info.maxMemory val memoryUsed = info.memoryUsed val diskUsed = info.diskUsed @@ -138,6 +142,21 @@ private[ui] class ExecutorsPage( {Utils.bytesToString(info.totalShuffleWrite)} + { + if (logsExist) { + + { + info.executorLogs.map { case (logName, logUrl) => + + } + } + + } + } { if (threadDumpEnabled) { val encodedId = URLEncoder.encode(info.id, "UTF-8") @@ -168,6 +187,7 @@ private[ui] class ExecutorsPage( val totalInputBytes = listener.executorToInputBytes.getOrElse(execId, 0L) val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0L) val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0L) + val executorLogs = listener.executorToLogUrls.getOrElse(execId, Map.empty) new ExecutorSummaryInfo( execId, @@ -183,7 +203,8 @@ private[ui] class ExecutorsPage( totalInputBytes, totalShuffleRead, totalShuffleWrite, - maxMem + maxMem, + executorLogs ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index dd1c2b78c409..69053fe44d7e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener +import org.apache.spark.storage.{StorageStatus, StorageStatusListener} import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { @@ -48,18 +48,26 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToInputRecords = HashMap[String, Long]() val executorToOutputBytes = HashMap[String, Long]() + val executorToOutputRecords = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() + val executorToLogUrls = HashMap[String, Map[String, String]]() - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { + val eid = executorAdded.executorId + executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap + } + + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId @@ -78,10 +86,14 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp metrics.inputMetrics.foreach { inputMetrics => executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead + executorToInputRecords(eid) = + executorToInputRecords.getOrElse(eid, 0L) + inputMetrics.recordsRead } metrics.outputMetrics.foreach { outputMetrics => executorToOutputBytes(eid) = executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + executorToOutputRecords(eid) = + executorToOutputRecords.getOrElse(eid, 0L) + outputMetrics.recordsWritten } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 045c69da06fe..bd923d78a86c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -42,7 +42,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } def makeRow(job: JobUIData): Seq[Node] = { - val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageInfo = Option(job.stageIds) + .filter(_.nonEmpty) + .flatMap { ids => listener.stageIdToInfo.get(ids.max) } val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 479f967fb154..527f960af2df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -128,7 +128,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { activeStagesTable.toNodeSeq } if (shouldShowPendingStages) { - content ++=

      Pending Stages ({pendingStages.size}

      ++ + content ++=

      Pending Stages ({pendingStages.size})

      ++ pendingStagesTable.toNodeSeq } if (shouldShowCompletedStages) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 9836d11a6d85..1f8536d1b719 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -36,6 +36,20 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { + val stageData = listener.stageIdToData.get((stageId, stageAttemptId)) + var hasInput = false + var hasOutput = false + var hasShuffleWrite = false + var hasShuffleRead = false + var hasBytesSpilled = false + stageData.foreach(data => { + hasInput = data.hasInput + hasOutput = data.hasOutput + hasShuffleRead = data.hasShuffleRead + hasShuffleWrite = data.hasShuffleWrite + hasBytesSpilled = data.hasBytesSpilled + }) + @@ -44,12 +58,32 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (hasInput) { + + }} + {if (hasOutput) { + + }} + {if (hasShuffleRead) { + + }} + {if (hasShuffleWrite) { + + }} + {if (hasBytesSpilled) { + + + }} {createExecutorTable()} @@ -76,18 +110,34 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage - - - - - - + {if (stageData.hasInput) { + + }} + {if (stageData.hasOutput) { + + }} + {if (stageData.hasShuffleRead) { + + }} + {if (stageData.hasShuffleWrite) { + + }} + {if (stageData.hasBytesSpilled) { + + + }} } case None => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 77d36209c604..7541d3e9c72e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -32,7 +32,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val jobId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val jobId = parameterId.toInt val jobDataOption = listener.jobIdToData.get(jobId) if (jobDataOption.isEmpty) { val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 4d200eeda86b..625596885faa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // These type aliases are public because they're used in the types of public fields: type JobId = Int + type JobGroupId = String type StageId = Int type StageAttemptId = Int type PoolName = String @@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] // Stages: val pendingStages = new HashMap[StageId, StageInfo] @@ -73,7 +75,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Misc: val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None @@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "jobIdToData" -> jobIdToData.size, "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size + "stageIdToStageInfo" -> stageIdToInfo.size, + "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, + // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: + "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size ) } @@ -140,13 +145,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (jobs.size > retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) jobs.take(toRemove).foreach { job => - jobIdToData.remove(job.jobId) + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + jobGroupToJobIds.remove(jobGroupId) + } + } + } } jobs.trimStart(toRemove) } } - override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobGroup = for ( props <- Option(jobStart.properties); group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + // A null jobGroupId is used for jobs that are run without a job group + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result @@ -182,7 +201,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) @@ -203,6 +222,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { for (stageId <- jobData.stageIds) { stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => jobsUsingStage.remove(jobEnd.jobId) + if (jobsUsingStage.isEmpty) { + stageIdToActiveJobIds.remove(stageId) + } stageIdToInfo.get(stageId).foreach { stageInfo => if (stageInfo.submissionTime.isEmpty) { // if this stage is pending, it won't complete, so mark it as "skipped": @@ -216,7 +238,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { val stage = stageCompleted.stageInfo stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { @@ -257,7 +279,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val stage = stageSubmitted.stageInfo activeStages(stage.stageId) = stage pendingStages.remove(stage.stageId) @@ -285,7 +307,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { @@ -309,7 +331,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // stageToTaskInfos already has the updated status. } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task // completion event is for. Let's just drop it here. This means we might have some speculation @@ -394,24 +416,48 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta + val shuffleWriteRecordsDelta = + (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L)) + stageData.shuffleWriteRecords += shuffleWriteRecordsDelta + execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta + val shuffleReadDelta = - (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) - stageData.shuffleReadBytes += shuffleReadDelta + (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L)) + stageData.shuffleReadTotalBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val shuffleReadRecordsDelta = + (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.shuffleReadRecords += shuffleReadRecordsDelta + execSummary.shuffleReadRecords += shuffleReadRecordsDelta + val inputBytesDelta = (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val inputRecordsDelta = + (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L)) + stageData.inputRecords += inputRecordsDelta + execSummary.inputRecords += inputRecordsDelta + val outputBytesDelta = (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) stageData.outputBytes += outputBytesDelta execSummary.outputBytes += outputBytesDelta + val outputRecordsDelta = + (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L)) + stageData.outputRecords += outputRecordsDelta + execSummary.outputRecords += outputRecordsDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index b2bbfdee5694..7ffcf291b5cc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -24,7 +24,7 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) val listener = parent.jobProgressListener attachPage(new AllJobsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 5fc6cc753315..f47cdc935e53 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -32,6 +32,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val poolName = request.getParameter("poolname") + require(poolName != null && poolName.nonEmpty, "Missing poolname parameter") + val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { case Some(s) => s.values.toSeq diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c..797c9404bc44 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} +import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -36,8 +36,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val stageId = request.getParameter("id").toInt - val stageAttemptId = request.getParameter("attempt").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val parameterAttempt = request.getParameter("attempt") + require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + + val stageId = parameterId.toInt + val stageAttemptId = parameterAttempt.toInt val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId)) if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { @@ -56,11 +62,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables val hasAccumulators = accumulables.size > 0 - val hasInput = stageData.inputBytes > 0 - val hasOutput = stageData.outputBytes > 0 - val hasShuffleRead = stageData.shuffleReadBytes > 0 - val hasShuffleWrite = stageData.shuffleWriteBytes > 0 - val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 val summary =
      @@ -69,31 +70,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total task time across all tasks: {UIUtils.formatDuration(stageData.executorRunTime)} - {if (hasInput) { + {if (stageData.hasInput) {
    • - Input: - {Utils.bytesToString(stageData.inputBytes)} + Input Size / Records: + {s"${Utils.bytesToString(stageData.inputBytes)} / ${stageData.inputRecords}"}
    • }} - {if (hasOutput) { + {if (stageData.hasOutput) {
    • Output: - {Utils.bytesToString(stageData.outputBytes)} + {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
    • }} - {if (hasShuffleRead) { + {if (stageData.hasShuffleRead) {
    • Shuffle read: - {Utils.bytesToString(stageData.shuffleReadBytes)} + {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + + s"${stageData.shuffleReadRecords}"}
    • }} - {if (hasShuffleWrite) { + {if (stageData.hasShuffleWrite) {
    • Shuffle write: - {Utils.bytesToString(stageData.shuffleWriteBytes)} + {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " + + s"${stageData.shuffleWriteRecords}"}
    • }} - {if (hasBytesSpilled) { + {if (stageData.hasBytesSpilled) {
    • Shuffle spill (memory): {Utils.bytesToString(stageData.memoryBytesSpilled)} @@ -132,6 +135,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time
    • + {if (stageData.hasShuffleRead) { +
    • + + + Shuffle Read Blocked Time + +
    • +
    • + + + Shuffle Remote Reads + +
    • + }}
    • @@ -151,7 +170,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
    • val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo) = + def accumulableRow(acc: AccumulableInfo): Elem = + val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) @@ -165,20 +185,33 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input", "")) else Nil} ++ - {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ - {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ - {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - else Nil} ++ + {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ + {if (stageData.hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ + {if (stageData.hasShuffleWrite) { + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ + {if (stageData.hasBytesSpilled) { + Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ Seq(("Errors", "")) val unzipped = taskHeadersAndCssClasses.unzip val taskTable = UIUtils.listingTable( unzipped._1, - taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, - hasBytesSpilled), + taskRow(hasAccumulators, stageData.hasInput, stageData.hasOutput, + stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled), tasks, headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics @@ -189,8 +222,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { None } else { + def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = + Distribution(data).get.getQuantiles() + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - Distribution(times).get.getQuantiles().map { millis => + getDistributionQuantiles(times).map { millis => } } @@ -233,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - if (info.gettingResultTime > 0) { - (info.finishTime - info.gettingResultTime).toDouble - } else { - 0.0 - } + getGettingResultTime(info).toDouble } val gettingResultQuantiles = ) + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = + getDistributionQuantiles(data).map(d => ) + + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { + val recordDist = getDistributionQuantiles(records).iterator + getDistributionQuantiles(data).map(d => + + ) + } val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) + + val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + + val inputQuantiles = +: + getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputQuantiles = +: getFormattedSizeQuantiles(outputSizes) - val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => + val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + } + + val outputQuantiles = +: + getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = + +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + + val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + } + val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + } + val shuffleReadTotalQuantiles = + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) + + val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } - val shuffleReadQuantiles = +: - getFormattedSizeQuantiles(shuffleReadSizes) + val shuffleReadRemoteQuantiles = + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: - getFormattedSizeQuantiles(shuffleWriteSizes) + + val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble + } + + val shuffleWriteQuantiles = +: + getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble @@ -306,12 +396,22 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, - if (hasInput) {inputQuantiles} else Nil, - if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, - if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) + if (stageData.hasInput) {inputQuantiles} else Nil, + if (stageData.hasOutput) {outputQuantiles} else Nil, + if (stageData.hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (stageData.hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if (stageData.hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (stageData.hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") @@ -360,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = info.gettingResultTime + val gettingResultTime = getGettingResultTime(info) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -370,21 +470,36 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val inputReadable = maybeInput .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") val maybeOutput = metrics.flatMap(_.outputMetrics) val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") val outputReadable = maybeOutput .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) - val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") - val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = - metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten) - val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite.map(Utils.bytesToString).getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead + .map(_.fetchWaitTime.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") @@ -440,17 +555,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasInput) { }} {if (hasOutput) { }} {if (hasShuffleRead) { + + }} {if (hasShuffleWrite) { @@ -458,7 +581,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {writeTimeReadable} }} {if (hasBytesSpilled) { @@ -500,16 +623,32 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime) + private def getGettingResultTime(info: TaskInfo): Long = { + if (info.gettingResultTime > 0) { + if (info.finishTime > 0) { + info.finishTime - info.gettingResultTime } else { - (info.finishTime - info.launchTime) + // The task is still fetching the result. + System.currentTimeMillis - info.gettingResultTime } + } else { + 0L } + } + + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { + val totalExecutionTime = + if (info.gettingResult) { + info.gettingResultTime - info.launchTime + } else if (info.finished) { + info.finishTime - info.launchTime + } else { + 0 + } val executorOverhead = (metrics.executorDeserializeTime + metrics.resultSerializationTime) - totalExecutionTime - metrics.executorRunTime - executorOverhead + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 703d43f9c640..cb72890a0fd2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -73,20 +73,21 @@ private[ui] class StageTableBase( } private def makeDescription(s: StageInfo): Seq[Node] = { - // scalastyle:off + val basePathUri = UIUtils.prependBaseUri(basePath) + val killLink = if (killEnabled) { - val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(basePath), s.stageId) - val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" - .format(s.stageId) - - (kill) - + val killLinkUri = s"$basePathUri/stages/stage/kill/" + val confirm = + s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" + + + + (kill) + } - // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) + val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -98,11 +99,9 @@ private[ui] class StageTableBase( @@ -138,7 +137,7 @@ private[ui] class StageTableBase( val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadBytes + val shuffleRead = stageData.shuffleReadTotalBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 937261de00e3..1bd2d87e0079 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -32,10 +32,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - def handleKillRequest(request: HttpServletRequest) = { - if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde4..9bf67db8acde 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,8 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" + val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 01f7e23212c3..711a3697bda1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -31,9 +31,13 @@ private[jobs] object UIData { var failedTasks : Int = 0 var succeededTasks : Int = 0 var inputBytes : Long = 0 + var inputRecords : Long = 0 var outputBytes : Long = 0 + var outputRecords : Long = 0 var shuffleRead : Long = 0 + var shuffleReadRecords : Long = 0 var shuffleWrite : Long = 0 + var shuffleWriteRecords : Long = 0 var memoryBytesSpilled : Long = 0 var diskBytesSpilled : Long = 0 } @@ -73,9 +77,13 @@ private[jobs] object UIData { var executorRunTime: Long = _ var inputBytes: Long = _ + var inputRecords: Long = _ var outputBytes: Long = _ - var shuffleReadBytes: Long = _ + var outputRecords: Long = _ + var shuffleReadTotalBytes: Long = _ + var shuffleReadRecords : Long = _ var shuffleWriteBytes: Long = _ + var shuffleWriteRecords: Long = _ var memoryBytesSpilled: Long = _ var diskBytesSpilled: Long = _ @@ -85,6 +93,12 @@ private[jobs] object UIData { var accumulables = new HashMap[Long, AccumulableInfo] var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] + + def hasInput: Boolean = inputBytes > 0 + def hasOutput: Boolean = outputBytes > 0 + def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 + def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 12d23a92878c..199f731b92bc 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -30,7 +30,10 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rddId = request.getParameter("id").toInt + val parameterId = request.getParameter("id") + require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") + + val rddId = parameterId.toInt val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index a81291d50558..045bd784990d 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -40,10 +40,10 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ - def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq + def rddInfoList: Seq[RDDInfo] = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { @@ -56,19 +56,19 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar * Assumes the storage status list is fully up-to-date. This implies the corresponding * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener. */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val metrics = taskEnd.taskMetrics if (metrics != null && metrics.updatedBlocks.isDefined) { updateRDDInfo(metrics.updatedBlocks.get) } } - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val rddInfos = stageSubmitted.stageInfo.rddInfos rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { // Remove all partitions that are no longer cached in current completed stage val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet _rddInfoMap.retain { case (id, info) => @@ -76,7 +76,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { _rddInfoMap.remove(unpersistRDD.rddId) } } diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0..81a7cbde01ce 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + override def isDefinedAt(o: Any): Boolean = { + val handled = _receiveWithLogging.isDefinedAt(o) + if (!handled) { + log.debug(s"Received unexpected actor system event: $o") + } + handled + } override def apply(o: Any): Unit = { if (log.isDebugEnabled) { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 4c9b1e3c46f0..b725df3b4459 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await -import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.duration.FiniteDuration import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -65,7 +65,8 @@ private[spark] object AkkaUtils extends Logging { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 120)) + val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", + conf.get("spark.network.timeout", "120s")) val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" @@ -77,10 +78,8 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) - val akkaFailureDetector = - conf.getDouble("spark.akka.failure-detector.threshold", 300.0) - val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s") + val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s") val secretKey = securityManager.getSecretKey() val isAuthOn = securityManager.isAuthenticationEnabled() @@ -91,8 +90,11 @@ private[spark] object AkkaUtils extends Logging { val secureCookie = if (isAuthOn) secretKey else "" logDebug(s"In createActorSystem, requireCookie is: $requireCookie") - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( - ConfigFactory.parseString( + val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig + .getOrElse(ConfigFactory.empty()) + + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] @@ -100,15 +102,14 @@ private[spark] object AkkaUtils extends Logging { |akka.jvm-exit-on-fatal-error = off |akka.remote.require-cookie = "$requireCookie" |akka.remote.secure-cookie = "$secureCookie" - |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s - |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s - |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector + |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s + |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s |akka.actor.provider = "akka.remote.RemoteActorRefProvider" |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" |akka.remote.netty.tcp.hostname = "$host" |akka.remote.netty.tcp.port = $port |akka.remote.netty.tcp.tcp-nodelay = on - |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s + |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B |akka.remote.netty.tcp.execution-pool-size = $akkaThreads |akka.actor.default-dispatcher.throughput = $akkaBatchSize @@ -124,16 +125,6 @@ private[spark] object AkkaUtils extends Logging { (actorSystem, boundPort) } - /** Returns the default Spark timeout to use for Akka ask operations. */ - def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds") - } - - /** Returns the default Spark timeout to use for Akka remote actor lookup. */ - def lookupTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - } - private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 /** Returns the configured max frame size for Akka messages in bytes. */ @@ -149,16 +140,6 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 - /** Returns the configured number of times to retry connecting */ - def numRetries(conf: SparkConf): Int = { - conf.getInt("spark.akka.num.retries", 3) - } - - /** Returns the configured number of milliseconds to wait on each retry */ - def retryWaitMs(conf: SparkConf): Int = { - conf.getInt("spark.akka.retry.wait", 3000) - } - /** * Send a message to the given actor and get its result within a default timeout, or * throw a SparkException if this fails. @@ -178,7 +159,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { @@ -214,8 +195,8 @@ private[spark] object AkkaUtils extends Logging { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - val url = s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) + val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } @@ -228,9 +209,33 @@ private[spark] object AkkaUtils extends Logging { actorSystem: ActorSystem): ActorRef = { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") - val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" - val timeout = AkkaUtils.lookupTimeout(conf) + val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def protocol(actorSystem: ActorSystem): String = { + val akkaConf = actorSystem.settings.config + val sslProp = "akka.remote.netty.tcp.enable-ssl" + protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) + } + + def protocol(ssl: Boolean = false): String = { + if (ssl) { + "akka.ssl.tcp" + } else { + "akka.tcp" + } + } + + def address( + protocol: String, + systemName: String, + host: String, + port: Any, + actorName: String): String = { + s"$protocol://$systemName@$host:$port/user/$actorName" + } + } diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala new file mode 100644 index 000000000000..ce7887b76ff9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.common.annotations.VisibleForTesting +import org.apache.spark.SparkContext + +/** + * Asynchronously passes events to registered listeners. + * + * Until `start()` is called, all posted events are only buffered. Only after this listener bus + * has started will events be actually propagated to all attached listeners. This listener bus + * is stopped when `stop()` is called, and it will drop further events after stopping. + * + * @param name name of the listener bus, will be the name of the listener thread. + * @tparam L type of listener + * @tparam E type of event + */ +private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String) + extends ListenerBus[L, E] { + + self => + + private var sparkContext: SparkContext = null + + /* Cap the capacity of the event queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY) + + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) + + // Indicate if we are processing some event + // Guarded by `self` + private var processingEvent = false + + // A counter that represents the number of events produced and consumed in the queue + private val eventLock = new Semaphore(0) + + private val listenerThread = new Thread(name) { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { + while (true) { + eventLock.acquire() + self.synchronized { + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } + } + } + } + } + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + * + * @param sc Used to stop the SparkContext in case the listener thread dies. + */ + def start(sc: SparkContext) { + if (started.compareAndSet(false, true)) { + sparkContext = sc + listenerThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + def post(event: E) { + if (stopped.get) { + // Drop further events to make `listenerThread` exit ASAP + logError(s"$name has already stopped! Dropping event $event") + return + } + val eventAdded = eventQueue.offer(event) + if (eventAdded) { + eventLock.release() + } else { + onDropEvent(event) + } + } + + /** + * For testing only. Wait until there are no more events in the queue, or until the specified + * time has elapsed. Return true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + @VisibleForTesting + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!queueIsEmpty) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and + * wait/notify add overhead in the general case. */ + Thread.sleep(10) + } + true + } + + /** + * For testing only. Return whether the listener daemon thread is still alive. + */ + @VisibleForTesting + def listenerThreadIsAlive: Boolean = listenerThread.isAlive + + /** + * Return whether the event queue is empty. + * + * The use of synchronized here guarantees that all events that once belonged to this queue + * have already been processed by all attached listeners, if this returns true. + */ + private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but drop the + * new events after stopping. + */ + def stop() { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know + // `stop` is called. + eventLock.release() + listenerThread.join() + } else { + // Keep quiet + } + } + + /** + * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be + * notified with the dropped events. + * + * Note: `onDropEvent` can be called in any thread. + */ + def onDropEvent(event: E): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala index 97c2b45aabf2..e92ed11bd165 100644 --- a/core/src/main/scala/org/apache/spark/util/Clock.scala +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -21,9 +21,47 @@ package org.apache.spark.util * An interface to represent clocks, so that they can be mocked out in unit tests. */ private[spark] trait Clock { - def getTime(): Long + def getTimeMillis(): Long + def waitTillTime(targetTime: Long): Long } -private[spark] object SystemClock extends Clock { - def getTime(): Long = System.currentTimeMillis() +/** + * A clock backed by the actual time from the OS as reported by the `System` API. + */ +private[spark] class SystemClock extends Clock { + + val minPollTime = 25L + + /** + * @return the same time (milliseconds since the epoch) + * as is reported by `System.currentTimeMillis()` + */ + def getTimeMillis(): Long = System.currentTimeMillis() + + /** + * @param targetTime block until the current time is at least this value + * @return current system time when wait has completed + */ + def waitTillTime(targetTime: Long): Long = { + var currentTime = 0L + currentTime = System.currentTimeMillis() + + var waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + + val pollTime = math.max(waitTime / 10.0, minPollTime).toLong + + while (true) { + currentTime = System.currentTimeMillis() + waitTime = targetTime - currentTime + if (waitTime <= 0) { + return currentTime + } + val sleepTime = math.min(waitTime, pollTime) + Thread.sleep(sleepTime) + } + -1 + } } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 390310243ee0..9044aaeef2d4 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -27,8 +27,8 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat // scalastyle:on private[this] var completed = false - def next() = sub.next() - def hasNext = { + def next(): A = sub.next() + def hasNext: Boolean = { val r = sub.hasNext if (!r && !completed) { completed = true @@ -37,13 +37,13 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat r } - def completion() + def completion(): Unit } private[spark] object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { new CompletionIterator[A,I](sub) { - def completion() = completionFunction + def completion(): Unit = completionFunction } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index a465298c8c5a..9aea8efa38c7 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -57,7 +57,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va out.println } - def statCounter = StatCounter(data.slice(startIdx, endIdx)) + def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) /** * print a summary of this distribution to the given PrintStream. diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b8442..e9b2b8d24b47 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f896b5072e4f..474f79fb756f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -32,7 +32,6 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ -import org.apache.hadoop.hdfs.web.JsonUtil /** * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- @@ -90,8 +89,9 @@ private[spark] object JsonProtocol { executorAddedToJson(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => executorRemovedToJson(executorRemoved) + case logStart: SparkListenerLogStart => + logStartToJson(logStart) // These aren't used, but keeps compiler happy - case SparkListenerShutdown => JNothing case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } @@ -204,13 +204,21 @@ private[spark] object JsonProtocol { def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ - ("Executor ID" -> executorRemoved.executorId) + ("Timestamp" -> executorRemoved.time) ~ + ("Executor ID" -> executorRemoved.executorId) ~ + ("Removed Reason" -> executorRemoved.reason) + } + + def logStartToJson(logStart: SparkListenerLogStart): JValue = { + ("Event" -> Utils.getFormattedClassName(logStart)) ~ + ("Spark Version" -> SPARK_VERSION) } /** ------------------------------------------------------------------- * @@ -292,22 +300,27 @@ private[spark] object JsonProtocol { ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ - ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) + ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ + ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ + ("Total Records Read" -> shuffleReadMetrics.recordsRead) } def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) + ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~ + ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten) } def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { ("Data Read Method" -> inputMetrics.readMethod.toString) ~ - ("Bytes Read" -> inputMetrics.bytesRead) + ("Bytes Read" -> inputMetrics.bytesRead) ~ + ("Records Read" -> inputMetrics.recordsRead) } def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = { ("Data Write Method" -> outputMetrics.writeMethod.toString) ~ - ("Bytes Written" -> outputMetrics.bytesWritten) + ("Bytes Written" -> outputMetrics.bytesWritten) ~ + ("Records Written" -> outputMetrics.recordsWritten) } def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { @@ -382,7 +395,8 @@ private[spark] object JsonProtocol { def executorInfoToJson(executorInfo: ExecutorInfo): JValue = { ("Host" -> executorInfo.executorHost) ~ - ("Total Cores" -> executorInfo.totalCores) + ("Total Cores" -> executorInfo.totalCores) ~ + ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } /** ------------------------------ * @@ -440,6 +454,7 @@ private[spark] object JsonProtocol { val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd) val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) + val logStart = Utils.getFormattedClassName(SparkListenerLogStart) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -457,6 +472,7 @@ private[spark] object JsonProtocol { case `applicationEnd` => applicationEndFromJson(json) case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) + case `logStart` => logStartFromJson(json) } } @@ -554,14 +570,22 @@ private[spark] object JsonProtocol { } def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] val executorInfo = executorInfoFromJson(json \ "Executor Info") - SparkListenerExecutorAdded(executorId, executorInfo) + SparkListenerExecutorAdded(time, executorId, executorInfo) } def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] - SparkListenerExecutorRemoved(executorId) + val reason = (json \ "Removed Reason").extract[String] + SparkListenerExecutorRemoved(time, executorId, reason) + } + + def logStartFromJson(json: JValue): SparkListenerLogStart = { + val sparkVersion = (json \ "Spark Version").extract[String] + SparkListenerLogStart(sparkVersion) } /** --------------------------------------------------------------------- * @@ -665,6 +689,8 @@ private[spark] object JsonProtocol { metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) + metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) + metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) metrics } @@ -672,13 +698,16 @@ private[spark] object JsonProtocol { val metrics = new ShuffleWriteMetrics metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long]) + metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written") + .extractOpt[Long].getOrElse(0)) metrics } def inputMetricsFromJson(json: JValue): InputMetrics = { val metrics = new InputMetrics( DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.addBytesRead((json \ "Bytes Read").extract[Long]) + metrics.incBytesRead((json \ "Bytes Read").extract[Long]) + metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0)) metrics } @@ -686,6 +715,7 @@ private[spark] object JsonProtocol { val metrics = new OutputMetrics( DataWriteMethod.withName((json \ "Data Write Method").extract[String])) metrics.setBytesWritten((json \ "Bytes Written").extract[Long]) + metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0)) metrics } @@ -788,7 +818,8 @@ private[spark] object JsonProtocol { def executorInfoFromJson(json: JValue): ExecutorInfo = { val executorHost = (json \ "Host").extract[String] val totalCores = (json \ "Total Cores").extract[Int] - new ExecutorInfo(executorHost, totalCores) + val logUrls = mapFromJson(json \ "Log Urls").toMap + new ExecutorInfo(executorHost, totalCores, logUrls) } /** -------------------------------- * diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala new file mode 100644 index 000000000000..a725767d08cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.CopyOnWriteArrayList + +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.scheduler.SparkListener + +/** + * An event bus which posts events to its listeners. + */ +private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + + // Marked `private[spark]` for access in tests. + private[spark] val listeners = new CopyOnWriteArrayList[L] + + /** + * Add a listener to listen events. This method is thread-safe and can be called in any thread. + */ + final def addListener(listener: L) { + listeners.add(listener) + } + + /** + * Post the event to all registered listeners. The `postToAll` caller should guarantee calling + * `postToAll` in the same thread for all events. + */ + final def postToAll(event: E): Unit = { + // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // However, this method will be called frequently. To avoid the wrapper cost, here ewe use + // Java Iterator directly. + val iter = listeners.iterator + while (iter.hasNext) { + val listener = iter.next() + try { + onPostEvent(listener, event) + } catch { + case NonFatal(e) => + logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } + } + } + + /** + * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same + * thread. + */ + def onPostEvent(listener: L, event: E): Unit + + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { + val c = implicitly[ClassTag[T]].runtimeClass + listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala new file mode 100644 index 000000000000..171855406198 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A `Clock` whose time can be manually set and modified. Its reported time does not change + * as time elapses, but only as its time is modified by callers. This is mainly useful for + * testing. + * + * @param time initial time (in milliseconds since the epoch) + */ +private[spark] class ManualClock(private var time: Long) extends Clock { + + /** + * @return `ManualClock` with initial time 0 + */ + def this() = this(0L) + + def getTimeMillis(): Long = + synchronized { + time + } + + /** + * @param timeToSet new time (in milliseconds) that the clock should represent + */ + def setTime(timeToSet: Long): Unit = synchronized { + time = timeToSet + notifyAll() + } + + /** + * @param timeToAdd time (in milliseconds) to add to the clock's time + */ + def advance(timeToAdd: Long): Unit = synchronized { + time += timeToAdd + notifyAll() + } + + /** + * @param targetTime block until the clock time is set or advanced to at least this time + * @return current time reported by the clock when waiting finishes + */ + def waitTillTime(targetTime: Long): Long = synchronized { + while (time < targetTime) { + wait(100) + } + getTimeMillis() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index ac40f19ed679..2bbfc988a99a 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -67,15 +67,16 @@ private[spark] object MetadataCleanerType extends Enumeration { type MetadataCleanerType = Value - def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = - "spark.cleaner.ttl." + which.toString + def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { + "spark.cleaner.ttl." + which.toString + } } // TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf) = { - conf.getInt("spark.cleaner.ttl", -1) + def getDelaySeconds(conf: SparkConf): Int = { + conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt } def getDelaySeconds( diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index 74fa77b68de0..dad888548ed1 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -43,7 +43,7 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef this } - override def toString = "(" + _1 + "," + _2 + ")" + override def toString: String = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala new file mode 100644 index 000000000000..1e0ba5c28754 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{URLClassLoader, URL} +import java.util.Enumeration +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ + +/** + * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. + */ +private[spark] class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends URLClassLoader(urls, parent) { + + override def addURL(url: URL): Unit = { + super.addURL(url) + } + + override def getURLs(): Array[URL] = { + super.getURLs() + } + +} + +/** + * A mutable class loader that gives preference to its own URLs over the parent class loader + * when loading classes and resources. + */ +private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends MutableURLClassLoader(urls, null) { + + private val parentClassLoader = new ParentClassLoader(parent) + + /** + * Used to implement fine-grained class loading locks similar to what is done by Java 7. This + * prevents deadlock issues when using non-hierarchical class loaders. + * + * Note that due to Java 6 compatibility (and some issues with implementing class loaders in + * Scala), Java 7's `ClassLoader.registerAsParallelCapable` method is not called. + */ + private val locks = new ConcurrentHashMap[String, Object]() + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + var lock = locks.get(name) + if (lock == null) { + val newLock = new Object() + lock = locks.putIfAbsent(name, newLock) + if (lock == null) { + lock = newLock + } + } + + lock.synchronized { + try { + super.loadClass(name, resolve) + } catch { + case e: ClassNotFoundException => + parentClassLoader.loadClass(name, resolve) + } + } + } + + override def getResource(name: String): URL = { + val url = super.findResource(name) + val res = if (url != null) url else parentClassLoader.getResource(name) + res + } + + override def getResources(name: String): Enumeration[URL] = { + val urls = super.findResources(name) + val res = + if (urls != null && urls.hasMoreElements()) { + urls + } else { + parentClassLoader.getResources(name) + } + res + } + + override def addURL(url: URL) { + super.addURL(url) + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 3abc12681fe9..73d126ff6254 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -18,15 +18,20 @@ package org.apache.spark.util /** - * A class loader which makes findClass accesible to the child + * A class loader which makes some protected methods in ClassLoader accesible. */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { - override def findClass(name: String) = { + override def findClass(name: String): Class[_] = { super.findClass(name) } override def loadClass(name: String): Class[_] = { super.loadClass(name) } + + override def loadClass(name: String, resolve: Boolean): Class[_] = { + super.loadClass(name, resolve) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala new file mode 100644 index 000000000000..f16cc8e7e42c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} + +object RpcUtils { + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + } + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.rpc.numRetries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Long = { + conf.getTimeAsMs("spark.rpc.retry.wait", "3s") + } + + /** Returns the default Spark timeout to use for RPC ask operations. */ + def askTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.askTimeout", + conf.get("spark.network.timeout", "120s")) seconds + } + + /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ + def lookupTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.lookupTimeout", + conf.get("spark.network.timeout", "120s")) seconds + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index 770ff9d5ad6a..a06b6f84ef11 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -27,7 +27,7 @@ import java.nio.channels.Channels */ private[spark] class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val length = in.readInt() diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index bce3b3afe9ab..26ffbf935038 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -18,18 +18,16 @@ package org.apache.spark.util import java.lang.management.ManagementFactory -import java.lang.reflect.{Array => JArray} -import java.lang.reflect.Field -import java.lang.reflect.Modifier -import java.util.IdentityHashMap -import java.util.Random +import java.lang.reflect.{Field, Modifier} +import java.util.{IdentityHashMap, Random} import java.util.concurrent.ConcurrentHashMap - import scala.collection.mutable.ArrayBuffer +import scala.runtime.ScalaRunTime import org.apache.spark.Logging import org.apache.spark.util.collection.OpenHashSet + /** * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in * memory-aware caches. @@ -184,9 +182,9 @@ private[spark] object SizeEstimator extends Logging { private val ARRAY_SIZE_FOR_SAMPLING = 200 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING - private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { - val length = JArray.getLength(array) - val elementClass = cls.getComponentType + private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) { + val length = ScalaRunTime.array_length(array) + val elementClass = arrayClass.getComponentType() // Arrays have object header and length field which is an integer var arrSize: Long = alignSize(objectSize + INT_SIZE) @@ -199,22 +197,26 @@ private[spark] object SizeEstimator extends Logging { state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { - for (i <- 0 until length) { - state.enqueue(JArray.get(array, i)) + var arrayIndex = 0 + while (arrayIndex < length) { + state.enqueue(ScalaRunTime.array_apply(array, arrayIndex).asInstanceOf[AnyRef]) + arrayIndex += 1 } } else { // Estimate the size of a large array by sampling elements without replacement. var size = 0.0 val rand = new Random(42) val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) - for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var numElementsDrawn = 0 + while (numElementsDrawn < ARRAY_SAMPLE_SIZE) { var index = 0 do { index = rand.nextInt(length) } while (drawn.contains(index)) drawn.add(index) - val elem = JArray.get(array, index) + val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] size += SizeEstimator.estimate(elem, state.visited) + numElementsDrawn += 1 } state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong } diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index d80eed455c42..8586da1996cf 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -141,8 +141,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { object StatCounter { /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) + def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values) /** Build a StatCounter from a list of values passed as variable-length arguments. */ - def apply(values: Double*) = new StatCounter(values) + def apply(values: Double*): StatCounter = new StatCounter(values) } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala new file mode 100644 index 000000000000..098a4b79496b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent._ + +import com.google.common.util.concurrent.ThreadFactoryBuilder + +private[spark] object ThreadUtils { + + /** + * Create a thread factory that names threads with a prefix and also sets the threads to daemon. + */ + def namedThreadFactory(prefix: String): ThreadFactory = { + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build() + } + + /** + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newSingleThreadExecutor. + */ + def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadExecutor(threadFactory) + } + + /** + * Wrapper over newSingleThreadScheduledExecutor. + */ + def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadScheduledExecutor(threadFactory) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f5be5856c210..310c0c109416 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -82,7 +82,7 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo this } - override def update(key: A, value: B) = this += ((key, value)) + override def update(key: A, value: B): Unit = this += ((key, value)) override def apply(key: A): B = internalMap.apply(key) @@ -92,14 +92,14 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) def toMap: Map[A, B] = iterator.toMap /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) /** Remove entries with values that are no longer strongly reachable. */ def clearNullValues() { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c04e4ddfbcb..667aa168e7ef 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,29 +21,34 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.{Locale, Properties, Random, UUID} +import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.concurrent._ +import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} -import com.google.common.util.concurrent.ThreadFactoryBuilder +import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ + +import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -60,6 +65,15 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + @volatile private var localRootDirs: Array[String] = null + + + private val shutdownHooks = new SparkShutdownHookManager() + shutdownHooks.install() + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -80,7 +94,7 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } ois.readObject.asInstanceOf[T] @@ -101,11 +115,10 @@ private[spark] object Utils extends Logging { /** Serialize via nested stream using specific serializer */ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)( - f: SerializationStream => Unit) = { + f: SerializationStream => Unit): Unit = { val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + override def write(b: Int): Unit = os.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) }) try { f(osWrapper) @@ -116,10 +129,9 @@ private[spark] object Utils extends Logging { /** Deserialize via nested stream using specific serializer */ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)( - f: DeserializationStream => Unit) = { + f: DeserializationStream => Unit): Unit = { val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - + override def read(): Int = is.read() override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) }) try { @@ -132,7 +144,7 @@ private[spark] object Utils extends Logging { /** * Get the ClassLoader which loaded Spark. */ - def getSparkClassLoader = getClass.getClassLoader + def getSparkClassLoader: ClassLoader = getClass.getClassLoader /** * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that @@ -141,7 +153,7 @@ private[spark] object Utils extends Logging { * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently * active loader when setting up ClassLoader delegation chains. */ - def getContextOrSparkClassLoader = + def getContextOrSparkClassLoader: ClassLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) /** Determines whether the provided class is loadable in the current thread. */ @@ -150,12 +162,14 @@ private[spark] object Utils extends Logging { } /** Preferred alternative to Class.forName(className) */ - def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + def classForName(className: String): Class[_] = { + Class.forName(className, true, getContextOrSparkClassLoader) + } /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -169,18 +183,16 @@ private[spark] object Utils extends Logging { private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } + addShutdownHook { () => + logDebug("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) } } - }) + } // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -209,8 +221,8 @@ private[spark] object Utils extends Logging { // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { val absolutePath = file.getPath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) } } @@ -246,13 +258,28 @@ private[spark] object Utils extends Logging { retval } + /** + * JDK equivalent of `chmod 700 file`. + * + * @param file the file whose permissions will be modified + * @return true if the permissions were successfully changed, false otherwise. + */ + def chmod700(file: File): Boolean = { + file.setReadable(false, false) && + file.setReadable(true, true) && + file.setWritable(false, false) && + file.setWritable(true, true) && + file.setExecutable(false, false) && + file.setExecutable(true, true) + } + /** * Create a directory inside the given parent directory. The directory is guaranteed to be * newly created, and is not marked for automatic deletion. */ - def createDirectory(root: String): File = { + def createDirectory(root: String, namePrefix: String = "spark"): File = { var attempts = 0 - val maxAttempts = 10 + val maxAttempts = MAX_DIR_CREATION_ATTEMPTS var dir: File = null while (dir == null) { attempts += 1 @@ -261,22 +288,24 @@ private[spark] object Utils extends Logging { maxAttempts + " attempts!") } try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) + dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) if (dir.exists() || !dir.mkdirs()) { dir = null } } catch { case e: SecurityException => dir = null; } } - dir + dir.getCanonicalFile } /** * Create a temporary directory inside the given parent directory. The directory will be * automatically deleted when the VM shuts down. */ - def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { - val dir = createDirectory(root) + def createTempDir( + root: String = System.getProperty("java.io.tmpdir"), + namePrefix: String = "spark"): File = { + val dir = createDirectory(root, namePrefix) registerShutdownDeleteDir(dir) dir } @@ -291,7 +320,7 @@ private[spark] object Utils extends Logging { transferToEnabled: Boolean = false): Long = { var count = 0L - try { + tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. @@ -331,7 +360,7 @@ private[spark] object Utils extends Logging { } } count - } finally { + } { if (closeStreams) { try { in.close() @@ -359,8 +388,10 @@ private[spark] object Utils extends Logging { } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * If `useCache` is true, first attempts to fetch the file to a local cache that's shared * across executors running the same application. `useCache` is used mainly for @@ -379,7 +410,8 @@ private[spark] object Utils extends Logging { useCache: Boolean) { val fileName = url.split("/").last val targetFile = new File(targetDir, fileName) - if (useCache) { + val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) + if (useCache && fetchCacheEnabled) { val cachedFileName = s"${url.hashCode}${timestamp}_cache" val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) @@ -410,13 +442,19 @@ private[spark] object Utils extends Logging { // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xzf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir) } else if (fileName.endsWith(".tar")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir) } // Make the file executable - That's necessary for scripts FileUtil.chmod(targetFile.getAbsolutePath, "a+x") + + // Windows does not grant read permission by default to non-admin users + // Add read permission to owner explicitly + if (isWindows) { + FileUtil.chmod(targetFile.getAbsolutePath, "u+r") + } } /** @@ -429,7 +467,6 @@ private[spark] object Utils extends Logging { * * @param url URL that `sourceFile` originated from, for logging purposes. * @param in InputStream to download. - * @param tempFile File path to download `in` to. * @param destFile File path to move `tempFile` to. * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match * `sourceFile` @@ -437,9 +474,11 @@ private[spark] object Utils extends Logging { private def downloadFile( url: String, in: InputStream, - tempFile: File, destFile: File, fileOverwrite: Boolean): Unit = { + val tempFile = File.createTempFile("fetchFileTemp", null, + new File(destFile.getParentFile.getAbsolutePath)) + logInfo(s"Fetching $url to $tempFile") try { val out = new FileOutputStream(tempFile) @@ -478,7 +517,7 @@ private[spark] object Utils extends Logging { removeSourceFile: Boolean = false): Unit = { if (destFile.exists) { - if (!Files.equal(sourceFile, destFile)) { + if (!filesEqualRecursive(sourceFile, destFile)) { if (fileOverwrite) { logInfo( s"File $destFile exists and does not match contents of $url, replacing it with $url" @@ -513,13 +552,44 @@ private[spark] object Utils extends Logging { Files.move(sourceFile, destFile) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") - Files.copy(sourceFile, destFile) + copyRecursive(sourceFile, destFile) + } + } + + private def filesEqualRecursive(file1: File, file2: File): Boolean = { + if (file1.isDirectory && file2.isDirectory) { + val subfiles1 = file1.listFiles() + val subfiles2 = file2.listFiles() + if (subfiles1.size != subfiles2.size) { + return false + } + subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall { + case (f1, f2) => filesEqualRecursive(f1, f2) + } + } else if (file1.isFile && file2.isFile) { + Files.equal(file1, file2) + } else { + false + } + } + + private def copyRecursive(source: File, dest: File): Unit = { + if (source.isDirectory) { + if (!dest.mkdir()) { + throw new IOException(s"Failed to create directory ${dest.getPath}") + } + val subfiles = source.listFiles() + subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) + } else { + Files.copy(source, dest) } } /** - * Download a file to target directory. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * Download a file or directory to target directory. Supports fetching the file in a variety of + * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based + * on the URL parameter. Fetching directories is only supported from Hadoop-compatible + * filesystems. * * Throws SparkException if the target file already exists and has different contents than * the requested file. @@ -531,14 +601,11 @@ private[spark] object Utils extends Logging { conf: SparkConf, securityMgr: SecurityManager, hadoopConf: Configuration) { - val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath)) val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { logDebug("fetchFile with security enabled") @@ -549,23 +616,56 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } + Utils.setupSecureURLConnection(uc, securityMgr) - val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 - uc.setConnectTimeout(timeout) - uc.setReadTimeout(timeout) + val timeoutMs = + conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 + uc.setConnectTimeout(timeoutMs) + uc.setReadTimeout(timeoutMs) uc.connect() val in = uc.getInputStream() - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + downloadFile(url, in, targetFile, fileOverwrite) case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) copyFile(url, sourceFile, targetFile, fileOverwrite) case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) - val in = fs.open(new Path(uri)) - downloadFile(url, in, tempFile, targetFile, fileOverwrite) + val path = new Path(uri) + fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite, + filename = Some(filename)) + } + } + + /** + * Fetch a file or directory from a Hadoop-compatible filesystem. + * + * Visible for testing + */ + private[spark] def fetchHcfsFile( + path: Path, + targetDir: File, + fs: FileSystem, + conf: SparkConf, + hadoopConf: Configuration, + fileOverwrite: Boolean, + filename: Option[String] = None): Unit = { + if (!targetDir.exists() && !targetDir.mkdir()) { + throw new IOException(s"Failed to create directory ${targetDir.getPath}") + } + val dest = new File(targetDir, filename.getOrElse(path.getName)) + if (fs.isFile(path)) { + val in = fs.open(path) + try { + downloadFile(path.toString, in, dest, fileOverwrite) + } finally { + in.close() + } + } else { + fs.listStatus(path).foreach { fileStatus => + fetchHcfsFile(fileStatus.getPath(), dest, fs, conf, hadoopConf, fileOverwrite) + } } } @@ -597,28 +697,56 @@ private[spark] object Utils extends Logging { * and returns only the directories that exist / could be created. * * If no directories could be created, this will return an empty list. + * + * This method will cache the local directories for the application when it's first invoked. + * So calling it multiple times with a different configuration will always return the same + * set of directories. */ private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = { - val confValue = if (isRunningInYarnContainer(conf)) { + if (localRootDirs == null) { + this.synchronized { + if (localRootDirs == null) { + localRootDirs = getOrCreateLocalRootDirsImpl(conf) + } + } + } + localRootDirs + } + + private def getOrCreateLocalRootDirsImpl(conf: SparkConf): Array[String] = { + if (isRunningInYarnContainer(conf)) { // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. - getYarnLocalDirs(conf) + // to what Yarn on this system said was available. Note this assumes that Yarn has + // created the directories already, and that they are secured so that only the + // user has access to them. + getYarnLocalDirs(conf).split(",") + } else if (conf.getenv("SPARK_EXECUTOR_DIRS") != null) { + conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else { - Option(conf.getenv("SPARK_LOCAL_DIRS")).getOrElse( - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - } - val rootDirs = confValue.split(',') - logDebug(s"Getting/creating local root dirs at '$confValue'") - - rootDirs.flatMap { rootDir => - val localDir: File = new File(rootDir) - val foundLocalDir = localDir.exists || localDir.mkdirs() - if (!foundLocalDir) { - logError(s"Failed to create local root dir in $rootDir. Ignoring this directory.") - None - } else { - Some(rootDir) - } + // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user + // configuration to point to a secure directory. So create a subdirectory with restricted + // permissions under each listed directory. + Option(conf.getenv("SPARK_LOCAL_DIRS")) + .getOrElse(conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + .split(",") + .flatMap { root => + try { + val rootDir = new File(root) + if (rootDir.exists || rootDir.mkdirs()) { + val dir = createTempDir(root) + chmod700(dir) + Some(dir.getAbsolutePath) + } else { + logError(s"Failed to create dir in $root. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create local root dir in $root. Ignoring this directory.") + None + } + } + .toArray } } @@ -637,6 +765,11 @@ private[spark] object Utils extends Logging { localDirs } + /** Used by unit tests. Do not call from other places. */ + private[spark] def clearLocalRootDirs(): Unit = { + localRootDirs = null + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method @@ -664,13 +797,12 @@ private[spark] object Utils extends Logging { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). * Note, this is typically not used from within core spark. */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) + private lazy val localIpAddress: InetAddress = findLocalInetAddress() - private def findLocalIpAddress(): String = { + private def findLocalInetAddress(): InetAddress = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") if (defaultIpOverride != null) { - defaultIpOverride + InetAddress.getByName(defaultIpOverride) } else { val address = InetAddress.getLocalHost if (address.isLoopbackAddress) { @@ -681,15 +813,20 @@ private[spark] object Utils extends Logging { // It's more proper to pick ip address following system output order. val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { + val addresses = ni.getInetAddresses.toList + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") + " a loopback address: " + address.getHostAddress + "; using " + + strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress + return strippedAddress } } logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + @@ -697,7 +834,7 @@ private[spark] object Utils extends Logging { " external IP address!") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") } - address.getHostAddress + address } } @@ -717,11 +854,14 @@ private[spark] object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) + customHostname.getOrElse(localIpAddress.getHostAddress) } - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName + /** + * Get the local machine's URI. + */ + def localHostNameForURI(): String = { + customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } def checkHost(host: String, message: String = "") { @@ -758,34 +898,6 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = - new ThreadFactoryBuilder().setDaemon(true) - - /** - * Create a thread factory that names threads with a prefix and also sets the threads to daemon. - */ - def namedThreadFactory(prefix: String): ThreadFactory = { - daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() - } - - /** - * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] - } - - /** - * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] - } - /** * Return the string to tell how long has passed in milliseconds. */ @@ -845,7 +957,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(dir.getPath(), true)) { + if (!client.delete(new TachyonURI(dir.getPath()), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } @@ -885,6 +997,22 @@ private[spark] object Utils extends Logging { ) } + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + def timeStringAsMs(str: String): Long = { + JavaUtils.timeStringAsMs(str) + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + def timeStringAsSeconds(str: String): Long = { + JavaUtils.timeStringAsSec(str) + } + /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. */ @@ -956,25 +1084,25 @@ private[spark] object Utils extends Logging { } /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. + * Execute a command and return the process running the command. */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) + def executeCommand( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): Process = { + val builder = new ProcessBuilder(command: _*).directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) } + val process = builder.start() + if (redirectStderr) { + val threadName = "redirect stderr for command " + command(0) + def log(s: String): Unit = logInfo(s) + processStreamByLine(threadName, process.getErrorStream, log) + } + process } /** @@ -983,31 +1111,13 @@ private[spark] object Utils extends Logging { def executeAndGetOutput( command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { - logInfo(line) - } - } - }.start() + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): String = { + val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - output.append(line) - } - } - } - stdoutThread.start() + val threadName = "read stdout for " + command(0) + def appendToOutput(s: String): Unit = output.append(s) + val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { @@ -1017,9 +1127,30 @@ private[spark] object Utils extends Logging { output.toString } + /** + * Return and start a daemon thread that processes the content of the input stream line by line. + */ + def processStreamByLine( + threadName: String, + inputStream: InputStream, + processLine: String => Unit): Thread = { + val t = new Thread(threadName) { + override def run() { + for (line <- Source.fromInputStream(inputStream).getLines()) { + processLine(line) + } + } + } + t.setDaemon(true) + t.start() + t + } + /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler + * + * NOTE: This method is to be called by the spark-started JVM process. */ def tryOrExit(block: => Unit) { try { @@ -1030,6 +1161,32 @@ private[spark] object Utils extends Logging { } } + /** + * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught + * exception + * + * NOTE: This method is to be called by the driver-side components to avoid stopping the + * user-started JVM process completely; in contrast, tryOrExit is to be called in the + * spark-started JVM process . + */ + def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) { + try { + block + } catch { + case e: ControlThrowable => throw e + case t: Throwable => + val currentThreadName = Thread.currentThread().getName + if (sc != null) { + logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t) + sc.stop() + } + if (!NonFatal(t)) { + logError(s"throw uncaught fatal error in thread $currentThreadName", t) + throw t + } + } + } + /** * Execute a block of code that evaluates to Unit, re-throwing any non-fatal uncaught * exceptions as IOException. This is used when implementing Externalizable and Serializable's @@ -1060,15 +1217,63 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + // It would be nice to find a method on Try that did this + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + // We could do originalThrowable.addSuppressed(t), but it's + // not available in JDK 1.6. + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when // finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r - val SCALA_CLASS_REGEX = """^scala""".r + val SCALA_CORE_CLASS_PREFIX = "scala" val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined - val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. isSparkCoreClass || isScalaClass } @@ -1311,9 +1516,14 @@ private[spark] object Utils extends Logging { hashAbs } - /** Returns a copy of the system properties that is thread-safe to iterator over. */ - def getSystemProperties(): Map[String, String] = { - System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String] + /** Returns the system properties map that is thread-safe to iterator over. It gets the + * properties which have been set explicitly, as well as those for which only a default value + * has been defined. */ + def getSystemProperties: Map[String, String] = { + val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield + (key, System.getProperty(key)) + + sysProps.toMap } /** @@ -1398,7 +1608,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef) = { + def getFormattedClassName(obj: AnyRef): String = { obj.getClass.getSimpleName.replace("$", "") } @@ -1411,7 +1621,7 @@ private[spark] object Utils extends Logging { } /** Return an empty JSON object */ - def emptyJson = JObject(List[JField]()) + def emptyJson: JsonAST.JObject = JObject(List[JField]()) /** * Return a Hadoop FileSystem with the scheme encoded in the given path. @@ -1459,7 +1669,7 @@ private[spark] object Utils extends Logging { /** * Indicates whether Spark is currently running unit tests. */ - def isTesting = { + def isTesting: Boolean = { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } @@ -1717,6 +1927,10 @@ private[spark] object Utils extends Logging { startService: Int => (T, Int), conf: SparkConf, serviceName: String = ""): (T, Int) = { + + require(startPort == 0 || (1024 <= startPort && startPort < 65536), + "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.") + val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { @@ -1779,6 +1993,20 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } + /** + * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and + * the host verifier from the given security manager. + */ + def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { + urlConnection match { + case https: HttpsURLConnection => + sm.sslSocketFactory.foreach(https.setSSLSocketFactory) + sm.hostnameVerifier.foreach(https.setHostnameVerifier) + https + case connection => connection + } + } + def invoke( clazz: Class[_], obj: AnyRef, @@ -1871,6 +2099,112 @@ private[spark] object Utils extends Logging { throw new SparkException("Invalid master URL: " + sparkUrl, e) } } + + /** + * Returns the current user name. This is the currently logged in user, unless that's been + * overridden by the `SPARK_USER` environment variable. + */ + def getCurrentUserName(): String = { + Option(System.getenv("SPARK_USER")) + .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. + */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Utils.logUncaughtExceptions(hooks.poll().run()) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + checkState() + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + } /** @@ -1887,7 +2221,7 @@ private[spark] class RedirectThread( override def run() { scala.util.control.Exception.ignoring(classOf[IOException]) { // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - try { + Utils.tryWithSafeFinally { val buf = new Array[Byte](1024) var len = in.read(buf) while (len != -1) { @@ -1895,7 +2229,7 @@ private[spark] class RedirectThread( out.flush() len = in.read(buf) } - } finally { + } { if (propagateEof) { out.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index af1f64649f35..41cb8cfe2afa 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask @@ -156,10 +156,10 @@ class BitSet(numBits: Int) extends Serializable { /** * Get an iterator over the set bits. */ - def iterator = new Iterator[Int] { + def iterator: Iterator[Int] = new Iterator[Int] { var ind = nextSetBit(0) override def hasNext: Boolean = ind >= 0 - override def next() = { + override def next(): Int = { val tmp = ind ind = nextSetBit(ind + 1) tmp diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a0f5a602de1..30dd7f22e494 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -151,15 +151,14 @@ class ExternalAppendOnlyMap[K, V, C]( override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] // Flush the disk writer's contents to disk, and update relevant variables - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -179,8 +178,7 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -355,7 +353,7 @@ class ExternalAppendOnlyMap[K, V, C]( val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { - def isEmpty = pairs.length == 0 + def isEmpty: Boolean = pairs.length == 0 // Invalid if there are no more pairs in this stream def minKeyHash: Int = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 6ba03841f746..79a695fb6208 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -53,7 +53,18 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do * want to do combining, having an Ordering is more efficient than not having it. * - * At a high level, this class works as follows: + * Users interact with this class in the following way: + * + * 1. Instantiate an ExternalSorter. + * + * 2. Call insertAll() with a set of records. + * + * 3. Request an iterator() back to traverse sorted/aggregated records. + * - or - + * Invoke writePartitionedFile() to create a file containing sorted/aggregated outputs + * that can be used in Spark's sort shuffle. + * + * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, @@ -65,11 +76,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * aggregation. For each file, we track how many objects were in each partition in memory, so we * don't have to write out the partition ID for every element. * - * - When the user requests an iterator, the spilled files are merged, along with any remaining - * in-memory data, using the same sort order defined above (unless both sorting and aggregation - * are disabled). If we need to aggregate by key, we either use a total ordering from the - * ordering parameter, or read the keys with the same hash code and compare them with each other - * for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. * @@ -259,8 +270,8 @@ private[spark] class ExternalSorter[K, V, C]( * Spill our in-memory collection to a sorted file that we can merge later (normal code path). * We add this file into spilledFiles to find it later. * - * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. - * See spillToPartitionedFiles() for that code path. + * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() + * is used to write files for each partition. * * @param collection whichever collection we're using (map or buffer) */ @@ -272,7 +283,8 @@ private[spark] class ExternalSorter[K, V, C]( // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -283,7 +295,7 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -308,7 +320,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -352,13 +365,20 @@ private[spark] class ExternalSorter[K, V, C]( // Create our file writers if we haven't done so yet if (partitionWriters == null) { curWriteMetrics = new ShuffleWriteMetrics() + val openStartTime = System.nanoTime partitionWriters = Array.fill(numPartitions) { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, + curWriteMetrics) + writer.open() } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) } // No need to sort stuff, just write each element out @@ -659,6 +679,8 @@ private[spark] class ExternalSorter[K, V, C]( } /** + * Exposed for testing purposes. + * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -668,7 +690,7 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { @@ -721,32 +743,29 @@ private[spark] class ExternalSorter[K, V, C]( // this simple we spill out the current in-memory collection so that everything is in files. spillToPartitionFiles(if (aggregator.isDefined) map else buffer) partitionWriters.foreach(_.commitAndClose()) - var out: FileOutputStream = null - var in: FileInputStream = null - try { - out = new FileOutputStream(outputFile, true) + val out = new FileOutputStream(outputFile, true) + val writeStartTime = System.nanoTime + util.Utils.tryWithSafeFinally { for (i <- 0 until numPartitions) { - in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - in.close() - in = null - lengths(i) = size - } - } finally { - if (out != null) { - out.close() - } - if (in != null) { - in.close() + val in = new FileInputStream(partitionWriters(i).fileSegment().file) + util.Utils.tryWithSafeFinally { + lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) + } { + in.close() + } } + } { + out.close() + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } } else { // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by // partition and just write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter( - blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics.shuffleWriteMetrics.get) for (elem <- elements) { writer.write(elem) } @@ -763,6 +782,7 @@ private[spark] class ExternalSorter[K, V, C]( if (curWriteMetrics != null) { m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) + m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) } } @@ -772,7 +792,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Read a partition file back as an iterator (used in our iterator method) */ - def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { if (writer.isOpen) { writer.commitAndClose() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index b8de4ff9aa49..efc2482c74dd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + if (k == null) { + haveNullValue + } else { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + } + /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { @@ -109,7 +118,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = -1 var nextPair: (K, V) = computeNextPair() @@ -132,9 +141,9 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4e363b74f4be..1501111a0665 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -85,7 +85,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( protected var _bitset = new BitSet(_capacity) - def getBitSet = _bitset + def getBitSet: BitSet = _bitset // Init of the array in constructor (instead of in declaration) to work around a Scala compiler // specialization bug that would generate two arrays (one for Object and one for specialized T). @@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def addWithoutResize(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 + var delta = 1 while (true) { if (!_bitset.get(pos)) { // This is a new key. @@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( // Found an existing key. return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** @@ -163,27 +161,25 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def getPos(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 - val maxProbe = _data.size - while (i < maxProbe) { + var delta = 1 + while (true) { if (!_bitset.get(pos)) { return INVALID_POS } else if (k == _data(pos)) { return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** Return the value at the specified position. */ def getValue(pos: Int): T = _data(pos) - def iterator = new Iterator[T] { + def iterator: Iterator[T] = new Iterator[T] { var pos = nextPos(0) override def hasNext: Boolean = pos != INVALID_POS override def next(): T = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 2e1ef06cbc4e..b4ec4ea52125 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -46,7 +46,12 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = _keySet.size + override def size: Int = _keySet.size + + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } /** Get the value for a given key */ def apply(k: K): V = { @@ -87,7 +92,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -103,9 +108,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 7e76d060d600..b6c380a8eea9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -71,12 +71,21 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: /** Resizes the array, dropping elements if the total length decreases. */ def resize(newLength: Int): PrimitiveVector[V] = { - val newArray = new Array[V](newLength) - _array.copyToArray(newArray) - _array = newArray + _array = copyArrayWithLength(newLength) if (newLength < _numElements) { _numElements = newLength } this } + + /** Return a trimmed version of the underlying array. */ + def toArray: Array[V] = { + copyArrayWithLength(size) + } + + private def copyArrayWithLength(length: Int): Array[V] = { + val copy = new Array[V](length) + _array.copyToArray(copy) + copy + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala index 65a7b4e0d497..dfcfb66af861 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingVector.scala @@ -36,11 +36,4 @@ private[spark] class SizeTrackingVector[T: ClassTag] resetSamples() this } - - /** - * Return a trimmed version of the underlying array. - */ - def toArray: Array[T] = { - super.iterator.toArray - } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 9f5431207485..747ecf075a39 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -42,9 +42,6 @@ private[spark] trait Spillable[C] extends Logging { // Memory manager that can be used to acquire/release memory private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - // Threshold for `elementsRead` before we start tracking this collection's memory usage - private[this] val trackMemoryThreshold = 1000 - // Initial threshold for the size of a collection before we start tracking its memory usage // Exposed for testing private[this] val initialMemoryThreshold: Long = @@ -72,8 +69,7 @@ private[spark] trait Spillable[C] extends Logging { * @return true if `collection` was spilled to disk; false otherwise */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - currentMemory >= myMemoryThreshold) { + if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index c5268c0fae0e..bdbca00a0062 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -32,7 +32,7 @@ private[spark] object Utils { */ def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = { val ordering = new GuavaOrdering[T] { - override def compare(l: T, r: T) = ord.compare(l, r) + override def compare(l: T, r: T): Int = ord.compare(l, r) } collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator } diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 1d5467060623..14b6ba4af489 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -121,7 +121,7 @@ private[spark] object FileAppender extends Logging { val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) - def createTimeBasedAppender() = { + def createTimeBasedAppender(): FileAppender = { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") @@ -149,7 +149,7 @@ private[spark] object FileAppender extends Logging { } } - def createSizeBasedAppender() = { + def createSizeBasedAppender(): FileAppender = { rollingSizeBytes match { case IntParam(bytes) => logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 76e7a2760bcd..786b97ad7b9e 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -105,7 +105,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals private val rng: Random = new XORShiftRandom - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (ub - lb <= 0.0) { @@ -131,7 +131,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals def cloneComplement(): BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, !complement) - override def clone = new BernoulliCellSampler[T](lb, ub, complement) + override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement) } @@ -153,7 +153,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T private val rng: Random = RandomSampler.newDefaultRNG - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { @@ -167,7 +167,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T } } - override def clone = new BernoulliSampler[T](fraction) + override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction) } @@ -209,7 +209,7 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] } } - override def clone = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) } @@ -228,15 +228,18 @@ class GapSamplingIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -244,21 +247,21 @@ class GapSamplingIterator[T: ClassTag]( override def next(): T = { val r = data.next() - advance + advance() r } private val lnq = math.log1p(-f) /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / lnq).toInt iterDrop(k) } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. @@ -279,15 +282,18 @@ class GapSamplingReplacementIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -300,7 +306,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( override def next(): T = { val r = v rep -= 1 - if (rep <= 0) advance + if (rep <= 0) advance() r } @@ -309,7 +315,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is * q is the probabililty of Poisson(0; f) */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / (-f)).toInt iterDrop(k) @@ -343,7 +349,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 2ae308dacf1a..9e29bf9d61f1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -311,7 +311,7 @@ private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: var acceptBound: Double = Double.NaN // upper bound for accepting item instantly var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist - def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + def areBoundsEmpty: Boolean = acceptBound.isNaN || waitListBound.isNaN def merge(other: Option[AcceptanceResult]): Unit = { if (other.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 467b890fb4bb..c4a7b4441c85 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -83,7 +83,7 @@ private[spark] object XORShiftRandom { * @return Map of execution times for {@link java.util.Random java.util.Random} * and XORShift */ - def benchmark(numIters: Int) = { + def benchmark(numIters: Int): Map[String, Long] = { val seed = 1L val million = 1e6.toInt diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 004de05c10ee..8a4f2a08fe70 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,11 +24,12 @@ import java.util.*; import java.util.concurrent.*; -import org.apache.spark.input.PortableDataStream; +import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; @@ -51,8 +52,11 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; +import org.apache.spark.rdd.RDD; +import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.StatCounter; @@ -267,6 +271,22 @@ public void call(String s) throws IOException { Assert.assertEquals(2, accum.value().intValue()); } + @Test + public void foreachPartition() { + final Accumulator accum = sc.accumulator(0); + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreachPartition(new VoidFunction>() { + @Override + public void call(Iterator iter) throws IOException { + while (iter.hasNext()) { + iter.next(); + accum.add(1); + } + } + }); + Assert.assertEquals(2, accum.value().intValue()); + } + @Test public void toLocalIterator() { List correct = Arrays.asList(1, 2, 3, 4); @@ -492,6 +512,36 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void treeReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeReduce(add, depth); + Assert.assertEquals(-5, sum); + } + } + + @Test + public void treeAggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeAggregate(0, add, add, depth); + Assert.assertEquals(-5, sum); + } + } + @SuppressWarnings("unchecked") @Test public void aggregateByKey() { @@ -627,6 +677,13 @@ public Boolean call(Integer i) { }).isEmpty()); } + @Test + public void toArray() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3)); + List list = rdd.toArray(); + Assert.assertEquals(Arrays.asList(1, 2, 3), list); + } + @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -673,11 +730,103 @@ public void javaDoubleRDDHistoGram() { Tuple2 results = rdd.histogram(2); double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1, 0.1); - Assert.assertArrayEquals(expected_counts, results._2); + Assert.assertArrayEquals(expected_buckets, results._1(), 0.1); + Assert.assertArrayEquals(expected_counts, results._2()); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); Assert.assertArrayEquals(expected_counts, histogram); + // SPARK-5744 + Assert.assertArrayEquals( + new long[] {0}, + sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0})); + } + + private static class DoubleComparator implements Comparator, Serializable { + public int compare(Double o1, Double o2) { + return o1.compareTo(o2); + } + } + + @Test + public void max() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(new DoubleComparator()); + Assert.assertEquals(4.0, max, 0.001); + } + + @Test + public void min() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(new DoubleComparator()); + Assert.assertEquals(1.0, max, 0.001); + } + + @Test + public void naturalMax() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(); + Assert.assertTrue(4.0 == max); + } + + @Test + public void naturalMin() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(); + Assert.assertTrue(1.0 == max); + } + + @Test + public void takeOrdered() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); + Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); + } + + @Test + public void top() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + List top2 = rdd.top(2); + Assert.assertEquals(Arrays.asList(4, 3), top2); + } + + private static class AddInts implements Function2 { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + } + + @Test + public void reduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.reduce(new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void reduceOnJavaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double sum = rdd.reduce(new Function2() { + @Override + public Double call(Double v1, Double v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(10.0, sum, 0.001); + } + + @Test + public void fold() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.fold(0, new AddInts()); + Assert.assertEquals(10, sum); + } + + @Test + public void aggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + int sum = rdd.aggregate(0, new AddInts(), new AddInts()); + Assert.assertEquals(10, sum); } @Test @@ -796,6 +945,25 @@ public Iterable call(Iterator iter) { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + + @Test + public void mapPartitionsWithIndex() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitionsWithIndex( + new Function2, Iterator>() { + @Override + public Iterator call(Integer index, Iterator iter) throws Exception { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum).iterator(); + } + }, false); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test public void repartition() { // Shrinking number of partitions @@ -1274,6 +1442,49 @@ public void checkpointAndRestore() { Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } + @Test + public void combineByKey() { + JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); + Function keyFunction = new Function() { + @Override + public Integer call(Integer v1) throws Exception { + return v1 % 3; + } + }; + Function createCombinerFunction = new Function() { + @Override + public Integer call(Integer v1) throws Exception { + return v1; + } + }; + + Function2 mergeValueFunction = new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }; + + JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + Map results = combinedRDD.collectAsMap(); + ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); + Assert.assertEquals(expected, results); + + Partitioner defaultPartitioner = Partitioner.defaultPartitioner( + combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD = originalRDD.keyBy(keyFunction) + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); + results = combinedRDD.collectAsMap(); + Assert.assertEquals(expected, results); + } + @SuppressWarnings("unchecked") @Test public void mapOnPairRDD() { @@ -1482,6 +1693,19 @@ public void collectAsync() throws Exception { Assert.assertEquals(1, future.jobIds().size()); } + @Test + public void takeAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction> future = rdd.takeAsync(1); + List result = future.get(); + Assert.assertEquals(1, result.size()); + Assert.assertEquals((Integer) 1, result.get(0)); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + @Test public void foreachAsync() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); diff --git a/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java new file mode 100644 index 000000000000..45772b6d3c20 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java @@ -0,0 +1,134 @@ +/** + * Copyright 2015 Stijn de Gouw + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package org.apache.spark.util.collection; + +import java.util.*; + +/** + * This codes generates a int array which fails the standard TimSort. + * + * The blog that reported the bug + * http://www.envisage-project.eu/timsort-specification-and-verification/ + * + * This codes was originally wrote by Stijn de Gouw, modified by Evan Yu to adapt to + * our test suite. + * + * https://github.com/abstools/java-timsort-bug + * https://github.com/abstools/java-timsort-bug/blob/master/LICENSE + */ +public class TestTimSort { + + private static final int MIN_MERGE = 32; + + /** + * Returns an array of integers that demonstrate the bug in TimSort + */ + public static int[] getTimSortBugTestSet(int length) { + int minRun = minRunLength(length); + List runs = runsJDKWorstCase(minRun, length); + return createArray(runs, length); + } + + private static int minRunLength(int n) { + int r = 0; // Becomes 1 if any 1 bits are shifted off + while (n >= MIN_MERGE) { + r |= (n & 1); + n >>= 1; + } + return n + r; + } + + private static int[] createArray(List runs, int length) { + int[] a = new int[length]; + Arrays.fill(a, 0); + int endRun = -1; + for (long len : runs) { + a[endRun += len] = 1; + } + a[length - 1] = 0; + return a; + } + + /** + * Fills runs with a sequence of run lengths of the form
      + * Y_n x_{n,1} x_{n,2} ... x_{n,l_n}
      + * Y_{n-1} x_{n-1,1} x_{n-1,2} ... x_{n-1,l_{n-1}}
      + * ...
      + * Y_1 x_{1,1} x_{1,2} ... x_{1,l_1}
      + * The Y_i's are chosen to satisfy the invariant throughout execution, + * but the x_{i,j}'s are merged (by TimSort.mergeCollapse) + * into an X_i that violates the invariant. + * + * @param length The sum of all run lengths that will be added to runs. + */ + private static List runsJDKWorstCase(int minRun, int length) { + List runs = new ArrayList(); + + long runningTotal = 0, Y = minRun + 4, X = minRun; + + while (runningTotal + Y + X <= length) { + runningTotal += X + Y; + generateJDKWrongElem(runs, minRun, X); + runs.add(0, Y); + // X_{i+1} = Y_i + x_{i,1} + 1, since runs.get(1) = x_{i,1} + X = Y + runs.get(1) + 1; + // Y_{i+1} = X_{i+1} + Y_i + 1 + Y += X + 1; + } + + if (runningTotal + X <= length) { + runningTotal += X; + generateJDKWrongElem(runs, minRun, X); + } + + runs.add(length - runningTotal); + return runs; + } + + /** + * Adds a sequence x_1, ..., x_n of run lengths to runs such that:
      + * 1. X = x_1 + ... + x_n
      + * 2. x_j >= minRun for all j
      + * 3. x_1 + ... + x_{j-2} < x_j < x_1 + ... + x_{j-1} for all j
      + * These conditions guarantee that TimSort merges all x_j's one by one + * (resulting in X) using only merges on the second-to-last element. + * + * @param X The sum of the sequence that should be added to runs. + */ + private static void generateJDKWrongElem(List runs, int minRun, long X) { + for (long newTotal; X >= 2 * minRun + 1; X = newTotal) { + //Default strategy + newTotal = X / 2 + 1; + //Specialized strategies + if (3 * minRun + 3 <= X && X <= 4 * minRun + 1) { + // add x_1=MIN+1, x_2=MIN, x_3=X-newTotal to runs + newTotal = 2 * minRun + 1; + } else if (5 * minRun + 5 <= X && X <= 6 * minRun + 5) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=X-newTotal to runs + newTotal = 3 * minRun + 3; + } else if (8 * minRun + 9 <= X && X <= 10 * minRun + 9) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=X-newTotal to runs + newTotal = 5 * minRun + 5; + } else if (13 * minRun + 15 <= X && X <= 16 * minRun + 17) { + // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=3MIN+4, x_6=X-newTotal to runs + newTotal = 8 * minRun + 9; + } + runs.add(0, X - newTotal); + } + runs.add(0, X); + } +} diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java similarity index 93% rename from core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java rename to core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java index e9ec700e32e1..e38bc38949d7 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.util; +package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; /** diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java new file mode 100644 index 000000000000..4a918f725dc9 --- /dev/null +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark; + +import org.apache.spark.TaskContext; + +/** + * Something to make sure that TaskContext can be used in Java. + */ +public class JavaTaskContextCompileCheck { + + public static void test() { + TaskContext tc = TaskContext.get(); + + tc.isCompleted(); + tc.isInterrupted(); + tc.isRunningLocally(); + + tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + + tc.attemptNumber(); + tc.partitionId(); + tc.stageId(); + tc.taskAttemptId(); + } +} diff --git a/core/src/test/resources/keystore b/core/src/test/resources/keystore new file mode 100644 index 000000000000..f8310e39ba1e Binary files /dev/null and b/core/src/test/resources/keystore differ diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/resources/test_metrics_system.properties b/core/src/test/resources/test_metrics_system.properties index 35d0bd3b8d0b..4e8b8465696e 100644 --- a/core/src/test/resources/test_metrics_system.properties +++ b/core/src/test/resources/test_metrics_system.properties @@ -18,7 +18,5 @@ *.sink.console.period = 10 *.sink.console.unit = seconds test.sink.console.class = org.apache.spark.metrics.sink.ConsoleSink -test.sink.dummy.class = org.apache.spark.metrics.sink.DummySink -test.source.dummy.class = org.apache.spark.metrics.source.DummySource test.sink.console.period = 20 test.sink.console.unit = minutes diff --git a/core/src/test/resources/truststore b/core/src/test/resources/truststore new file mode 100644 index 000000000000..a6b1d46e1f39 Binary files /dev/null and b/core/src/test/resources/truststore differ diff --git a/core/src/test/resources/untrusted-keystore b/core/src/test/resources/untrusted-keystore new file mode 100644 index 000000000000..6015b02caa12 Binary files /dev/null and b/core/src/test/resources/untrusted-keystore differ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index f087fc550dde..75399461f2a5 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import scala.collection.mutable +import scala.ref.WeakReference import org.scalatest.FunSuite import org.scalatest.Matchers @@ -26,19 +27,20 @@ import org.scalatest.Matchers class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { - implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { - def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[A]) : mutable.Set[A] = { - new mutable.HashSet[A]() + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = + new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } } - } test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -48,11 +50,10 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { d.foreach{x => acc += x} acc.value should be (210) - - val longAcc = sc.accumulator(0l) + val longAcc = sc.accumulator(0L) val maxInt = Integer.MAX_VALUE.toLong d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210l + maxInt * 20) + longAcc.value should be (210L + maxInt * 20) } test ("value not assignable from tasks") { @@ -136,4 +137,23 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { } } + test ("garbage collection") { + // Create an accumulator and let it go out of scope to test that it's properly garbage collected + sc = new SparkContext("local", "test") + var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val accId = acc.id + val ref = WeakReference(acc) + + // Ensure the accumulator is present + assert(ref.get.isDefined) + + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + Accumulators.remove(accId) + assert(!Accumulators.originals.get(accId).isDefined) + } + } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d7d9dc7b50f3..70529d921659 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { - var sc : SparkContext = _ +class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter + with MockitoSugar { + var blockManager: BlockManager = _ var cacheManager: CacheManager = _ var split: Partition = _ @@ -43,24 +45,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar rdd = new RDD[Int](sc, Nil) { override def getPartitions: Array[Partition] = Array(split) override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator + override def compute(split: Partition, context: TaskContext): Iterator[Int] = + Array(1, 2, 3, 4).iterator } rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() } - after { - sc.stop() - } - test("get uncached rdd") { // Do not mock this test, because attempting to match Array[Any], which is not covariant, // in blockManager.put is a losing battle. You have been warned. @@ -75,29 +74,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } test("get cached rdd") { - expecting { - val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) - blockManager.get(RDDBlockId(0, 0)).andReturn(Some(result)) - } + val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(5, 6, 7)) - } + val context = new TaskContextImpl(0, 0, 0, 0) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(5, 6, 7)) } test("get uncached local rdd") { - expecting { - // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get(RDDBlockId(0, 0)).andReturn(None) - } + // Local computation should not persist the resulting value, so don't expect a put(). + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0, true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } + val context = new TaskContextImpl(0, 0, 0, 0, true) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 3b10b3a04231..e1faddeabec7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -33,8 +33,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { override def beforeEach() { super.beforeEach() - checkpointDir = File.createTempFile("temp", "") - checkpointDir.deleteOnExit() + checkpointDir = File.createTempFile("temp", "", Utils.createTempDir()) checkpointDir.delete() sc = new SparkContext("local", "test") sc.setCheckpointDir(checkpointDir.toString) @@ -76,7 +75,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.partitions.toList === + parCollection.checkpointData.get.getPartitions.toList) assert(parCollection.collect() === result) } @@ -103,13 +103,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(_.union(otherRDD)) testRDDPartitions(_.union(otherRDD)) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(new CartesianRDD(sc, _, otherRDD)) testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) @@ -224,7 +224,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( - partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass, + partitionBeforeCheckpoint.parents.head.getClass != + partitionAfterCheckpoint.parents.head.getClass, "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed" ) } @@ -359,7 +360,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * Generate an pair RDD (with partitioner) such that both the RDD and its partitions * have large size. */ - def generateFatPairRDD() = { + def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -446,7 +447,8 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index ae2ae7ed0d3a..c7868ddcf770 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.storage._ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -64,7 +65,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha } } - //------ Helper functions ------ + // ------ Helper functions ------ protected def newRDD() = sc.makeRDD(1 to 10) protected def newPairRDD() = newRDD().map(_ -> 1) @@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { postGCTester.assertCleanup() } + test("automatically cleanup checkpoint") { + val checkpointDir = java.io.File.createTempFile("temp", "") + checkpointDir.deleteOnExit() + checkpointDir.delete() + var rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + var rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined) + val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get + val fs = path.getFileSystem(sc.hadoopConfiguration) + assert(fs.exists(path)) + + // the checkpoint is not cleaned by default (without the configuration set) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + sc.stop() + val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint"). + set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") + sc = new SparkContext(conf) + rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + // Test that GC causes checkpoint data cleanup after dereferencing the RDD + postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + } + test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly @@ -359,18 +406,20 @@ class CleanerTester( sc: SparkContext, rddIds: Seq[Int] = Seq.empty, shuffleIds: Seq[Int] = Seq.empty, - broadcastIds: Seq[Long] = Seq.empty) + broadcastIds: Seq[Long] = Seq.empty, + checkpointIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val toBeCheckpointIds = new HashSet[Long] with SynchronizedSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { toBeCleanedRDDIds -= rddId - logInfo("RDD "+ rddId + " cleaned") + logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { @@ -380,7 +429,16 @@ class CleanerTester( def broadcastCleaned(broadcastId: Long): Unit = { toBeCleanedBroadcstIds -= broadcastId - logInfo("Broadcast" + broadcastId + " cleaned") + logInfo("Broadcast " + broadcastId + " cleaned") + } + + def accumCleaned(accId: Long): Unit = { + logInfo("Cleaned accId " + accId + " cleaned") + } + + def checkpointCleaned(rddId: Long): Unit = { + toBeCheckpointIds -= rddId + logInfo("checkpoint " + rddId + " cleaned") } } @@ -405,7 +463,8 @@ class CleanerTester( /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty || + checkpointIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present rddIds.foreach { rddId => @@ -496,7 +555,8 @@ class CleanerTester( private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty && - toBeCleanedBroadcstIds.isEmpty + toBeCleanedBroadcstIds.isEmpty && + toBeCheckpointIds.isEmpty private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 8a54360e8179..9bd5dfec8703 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -28,31 +28,29 @@ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { + test("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + val masters = Table("master", "local", "local-cluster[2,1,512]") forAll(masters) { (master: String) => - failAfter(60 seconds) { - Utils.executeAndGetOutput( - Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - } + val process = Utils.executeCommand( + Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } } } /** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. + * Program that creates a Spark driver but doesn't call SparkContext#stop() or + * sys.exit() after finishing. */ object DriverWithoutCleanup { def main(args: Array[String]) { Utils.configTestLog4j("INFO") - // Bind the web UI to an ephemeral port in order to avoid conflicts with other tests running on - // the same machine (we shouldn't just disable the UI here, since that might mask bugs): - val conf = new SparkConf().set("spark.ui.port", "0") + val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 0e4df17c1bf8..22acc270b983 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,45 +19,50 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { - // No min or max val conf = new SparkConf() .setMaster("local") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") - intercept[SparkException] { new SparkContext(conf) } - SparkEnv.get.stop() // cleanup the created environment - SparkContext.clearActiveContext() - - // Only min - val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") - intercept[SparkException] { new SparkContext(conf1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() - - // Only max - val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") - intercept[SparkException] { new SparkContext(conf2) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() + val sc0 = new SparkContext(conf) + contexts += sc0 + assert(sc0.executorAllocationManager.isDefined) + sc0.stop() + + // Min < 0 + val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") + intercept[SparkException] { contexts += new SparkContext(conf1) } + + // Max < 0 + val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") + intercept[SparkException] { contexts += new SparkContext(conf2) } // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) @@ -145,8 +150,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { // Verify that running a task reduces the cap sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsPending(manager) === 4) assert(addExecutors(manager) === 1) @@ -176,6 +181,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsPending(manager) === 9) } + test("cancel pending executors when no longer needed") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + + val task1Info = createTaskInfo(0, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + + val task2Info = createTaskInfo(1, 0, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + + assert(adjustRequestedExecutors(manager) === -1) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -271,15 +303,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeExecutor(manager, "5")) assert(removeExecutor(manager, "6")) assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 0) // still at upper limit + assert(addExecutors(manager) === 1) onExecutorRemoved(manager, "3") onExecutorRemoved(manager, "4") assert(executorIds(manager).size === 8) // Add succeeds again, now that we are no longer at the upper limit // Number of executors added restarts at 1 - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 1) // upper limit reached assert(addExecutors(manager) === 0) assert(executorIds(manager).size === 8) onExecutorRemoved(manager, "5") @@ -287,9 +319,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "13") onExecutorAdded(manager, "14") assert(executorIds(manager).size === 8) - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) // upper limit reached again - assert(addExecutors(manager) === 0) + assert(addExecutors(manager) === 0) // still at upper limit onExecutorAdded(manager, "15") onExecutorAdded(manager, "16") assert(executorIds(manager).size === 10) @@ -297,7 +327,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling add timer") { sc = createSparkContext(2, 10) - val clock = new TestClock(8888L) + val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -306,21 +336,21 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onSchedulerBacklogged(manager) val firstAddTime = addTime(manager) assert(firstAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onSchedulerBacklogged(manager) assert(addTime(manager) === firstAddTime) onSchedulerQueueEmpty(manager) // Restart add timer - clock.tick(1000L) + clock.advance(1000L) assert(addTime(manager) === NOT_SET) onSchedulerBacklogged(manager) val secondAddTime = addTime(manager) assert(secondAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onSchedulerBacklogged(manager) assert(addTime(manager) === secondAddTime) // timer is already started assert(addTime(manager) !== firstAddTime) @@ -329,7 +359,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("starting/canceling remove timers") { sc = createSparkContext(2, 10) - val clock = new TestClock(14444L) + val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -342,17 +372,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("1")) val firstRemoveTime = removeTimes(manager)("1") assert(firstRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(100L) + clock.advance(100L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) // timer is already started - clock.tick(200L) + clock.advance(200L) onExecutorIdle(manager, "1") assert(removeTimes(manager)("1") === firstRemoveTime) - clock.tick(300L) + clock.advance(300L) onExecutorIdle(manager, "2") assert(removeTimes(manager)("2") !== firstRemoveTime) // different executor assert(removeTimes(manager)("2") === clock.getTimeMillis + executorIdleTimeout * 1000) - clock.tick(400L) + clock.advance(400L) onExecutorIdle(manager, "3") assert(removeTimes(manager)("3") !== firstRemoveTime) assert(removeTimes(manager)("3") === clock.getTimeMillis + executorIdleTimeout * 1000) @@ -361,7 +391,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("3")) // Restart remove timer - clock.tick(1000L) + clock.advance(1000L) onExecutorBusy(manager, "1") assert(removeTimes(manager).size === 2) onExecutorIdle(manager, "1") @@ -377,7 +407,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop with no events") { sc = createSparkContext(1, 20) val manager = sc.executorAllocationManager.get - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) manager.setClock(clock) // No events - we should not be adding or removing @@ -386,15 +416,15 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(100L) + clock.advance(100L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(1000L) + clock.advance(1000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(10000L) + clock.advance(10000L) schedule(manager) assert(numExecutorsPending(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) @@ -402,57 +432,57 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("mock polling loop add behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000 / 2) + clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 0) // timer not exceeded yet - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1) // first timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000 / 2) + clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained onSchedulerQueueEmpty(manager) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // timer is canceled - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) - clock.tick(schedulerBacklogTimeout * 1000) + clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1) // timer restarted - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) - clock.tick(sustainedSchedulerBacklogTimeout * 1000) + clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) assert(numExecutorsPending(manager) === 20) // limit reached } test("mock polling loop remove behavior") { sc = createSparkContext(1, 20) - val clock = new TestClock(2020L) + val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -462,11 +492,11 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { onExecutorAdded(manager, "executor-3") assert(removeTimes(manager).size === 3) assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000 / 2) + clock.advance(executorIdleTimeout * 1000 / 2) schedule(manager) assert(removeTimes(manager).size === 3) // idle threshold not reached yet assert(executorsPendingToRemove(manager).isEmpty) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle threshold exceeded assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining) @@ -487,7 +517,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(!removeTimes(manager).contains("executor-5")) assert(!removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 2) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) // idle executors are removed assert(executorsPendingToRemove(manager).size === 4) @@ -505,7 +535,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-5")) assert(removeTimes(manager).contains("executor-6")) assert(executorsPendingToRemove(manager).size === 4) - clock.tick(executorIdleTimeout * 1000) + clock.advance(executorIdleTimeout * 1000) schedule(manager) assert(removeTimes(manager).isEmpty) assert(executorsPendingToRemove(manager).size === 6) // limit reached (1 executor remaining) @@ -579,30 +609,28 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host2", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) assert(removeTimes(manager).size === 2) assert(removeTimes(manager).contains("executor-2")) // Existing executors have disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-1", "host1", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( - 0L, BlockManagerId("executor-3", "host3", 1))) + sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -614,8 +642,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).isEmpty) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) @@ -626,32 +654,22 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( - 0L, BlockManagerId("executor-2", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerExecutorAdded( + 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. - */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -660,14 +678,28 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) - .set("spark.dynamicAllocation.schedulerBacklogTimeout", schedulerBacklogTimeout.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", + s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", - sustainedSchedulerBacklogTimeout.toString) - .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) + s"${sustainedSchedulerBacklogTimeout.toString}s") + .set("spark.dynamicAllocation.executorIdleTimeout", s"${executorIdleTimeout.toString}s") .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } @@ -682,6 +714,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded) private val _executorsPendingToRemove = PrivateMethod[collection.Set[String]]('executorsPendingToRemove) private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds) @@ -689,6 +722,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) private val _schedule = PrivateMethod[Unit]('schedule) private val _addExecutors = PrivateMethod[Int]('addExecutors) + private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests) private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) @@ -727,7 +761,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { } private def addExecutors(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _addExecutors() + val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() + manager invokePrivate _addExecutors(maxNumExecutorsNeeded) + } + + private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _addOrCancelExecutorRequests(0L) } private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 0f49ce4754fb..a69e9b761f9a 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -18,13 +18,19 @@ package org.apache.spark import java.io._ +import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} +import javax.net.ssl.SSLException import com.google.common.io.ByteStreams +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite import org.apache.spark.util.Utils +import SSLSampleConfigs._ + class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpDir: File = _ @@ -168,4 +174,88 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } } + test ("HttpFileServer should work with SSL") { + val sparkConf = sparkSSLConfig() + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + fileTransferTest(server, sm) + } finally { + server.stop() + } + } + + test ("HttpFileServer should work with SSL and good credentials") { + val sparkConf = sparkSSLConfig() + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + fileTransferTest(server, sm) + } finally { + server.stop() + } + } + + test ("HttpFileServer should not work with valid SSL and bad credentials") { + val sparkConf = sparkSSLConfig() + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "bad") + + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + intercept[IOException] { + fileTransferTest(server) + } + } finally { + server.stop() + } + } + + test ("HttpFileServer should not work with SSL when the server is untrusted") { + val sparkConf = sparkSSLConfigUntrusted() + val sm = new SecurityManager(sparkConf) + val server = new HttpFileServer(sparkConf, sm, 0) + try { + server.initialize() + + intercept[SSLException] { + fileTransferTest(server) + } + } finally { + server.stop() + } + } + + def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { + val randomContent = RandomUtils.nextBytes(100) + val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) + FileUtils.writeByteArrayToFile(file, randomContent) + server.addFile(file) + + val uri = new URI(server.serverUri + "/files/" + file.getName) + + val connection = if (sm != null && sm.isAuthenticationEnabled()) { + Utils.constructURIForAuthentication(uri, sm).toURL.openConnection() + } else { + uri.toURL.openConnection() + } + + if (sm != null) { + Utils.setupSecureURLConnection(connection, sm) + } + + val buf = IOUtils.toByteArray(connection.getInputStream) + assert(buf === randomContent) + } + } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5e24196101fb..c8f08eed47c7 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils @@ -223,7 +222,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) nums.saveAsSequenceFile(outputDir) val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } @@ -452,7 +451,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) } @@ -460,8 +460,10 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the non-empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) @@ -472,16 +474,20 @@ class FileSuite extends FunSuite with LocalSparkContext { val sf = new SparkConf() sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") sc = new SparkContext(sf) - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) } test ("save Hadoop Dataset through old Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new JobConf() job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) @@ -493,7 +499,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("save Hadoop Dataset through new Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new Job(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 000000000000..0fd570e5297d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index d895230ecf33..51348c039b5c 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -51,7 +51,7 @@ private object ImplicitOrderingSuite { override def compare(o: OrderedClass): Int = ??? } - def basicMapExpectations(rdd: RDD[Int]) = { + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), (rdd.map(x => (1, x)).keyOrdering.isDefined, @@ -68,7 +68,7 @@ private object ImplicitOrderingSuite { "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - def otherRDDMethodExpectations(rdd: RDD[Int]) = { + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, @@ -82,4 +82,4 @@ private object ImplicitOrderingSuite { (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined, "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined")) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc92..4d3e09793faf 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,9 +186,9 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) - }.reduceByKey(_+_) + }.reduceByKey(_ + _) val f1 = rdd.collectAsync() val f2 = rdd.countAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. - // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 53e367a61715..8bf2e55defd0 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -37,7 +37,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self super.afterEach() } - def resetSparkContext() = { + def resetSparkContext(): Unit = { LocalSparkContext.stop(sc) sc = null } @@ -54,7 +54,7 @@ object LocalSparkContext { } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { try { f(sc) } finally { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d27880f4bc32..6ed057a7cab9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,37 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, isA} import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, + securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { + RpcEnv.create(name, host, port, conf, securityManager) + } + test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,43 +147,47 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch below akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("spark") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + verify(rpcCallContext).reply(any()) + verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch exceeds akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("test") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) @@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + verify(rpcCallContext, never()).reply(any()) + verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index b7532314ada0..47e3bf6e1ac4 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -92,7 +92,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row) = x.value - y.value + override def compare(x: Row, y: Row): Int = x.value - y.value } val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) @@ -212,20 +212,24 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + def verify(testFun: => Unit): Unit = { + intercept[SparkException](testFun).getMessage.contains("array") + } + + verify(arrs.distinct()) // We can't catch all usages of arrays, since they might occur inside other collections: // assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + verify(arrPairs.partitionBy(new HashPartitioner(2))) + verify(arrPairs.join(arrPairs)) + verify(arrPairs.leftOuterJoin(arrPairs)) + verify(arrPairs.rightOuterJoin(arrPairs)) + verify(arrPairs.fullOuterJoin(arrPairs)) + verify(arrPairs.groupByKey()) + verify(arrPairs.countByKey()) + verify(arrPairs.countByKeyApprox(1)) + verify(arrPairs.cogroup(arrPairs)) + verify(arrPairs.reduceByKeyLocally(_ + _)) + verify(arrPairs.reduceByKey(_ + _)) } test("zero-length partitions should be correctly handled") { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala new file mode 100644 index 000000000000..93f46ef11c0e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +import com.google.common.io.Files +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { + + test("test resolving property file as spark conf ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val opts = SSLOptions.parse(conf, "spark.ssl") + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + test("test resolving property with defaults specified ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + test("test whether defaults can be overridden ") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ui.ssl.enabled", "false") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + + assert(opts.enabled === false) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("12345")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === Set("ABC", "DEF")) + } + +} diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala new file mode 100644 index 000000000000..308b9ea17708 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +object SSLSampleConfigs { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val untrustedKeyStorePath = new File( + this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + def sparkSSLConfig(): SparkConf = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + conf.set("spark.ssl.protocol", "TLSv1") + conf + } + + def sparkSSLConfigUntrusted(): SparkConf = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", untrustedKeyStorePath) + conf.set("spark.ssl.keyStorePassword", "password") + conf.set("spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + conf.set("spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + conf.set("spark.ssl.protocol", "TLSv1") + conf + } + +} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index fcca0867b807..62cb7649c028 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark -import scala.collection.mutable.ArrayBuffer +import java.io.File import org.scalatest.FunSuite +import org.apache.spark.util.Utils + class SecurityManagerSuite extends FunSuite { test("set security with conf") { @@ -125,6 +127,53 @@ class SecurityManagerSuite extends FunSuite { } + test("ssl on setup") { + val conf = SSLSampleConfigs.sparkSSLConfig() + + val securityManager = new SecurityManager(conf) + + assert(securityManager.fileServerSSLOptions.enabled === true) + assert(securityManager.akkaSSLOptions.enabled === true) + + assert(securityManager.sslSocketFactory.isDefined === true) + assert(securityManager.hostnameVerifier.isDefined === true) + + assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) + assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") + assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) + assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") + assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) + assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) + assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + + assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) + assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") + assert(securityManager.akkaSSLOptions.keyStore.isDefined === true) + assert(securityManager.akkaSSLOptions.keyStore.get.getName === "keystore") + assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) + assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) + assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + } + + test("ssl off setup") { + val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) + + System.setProperty("spark.ssl.configFile", file.getAbsolutePath) + val conf = new SparkConf() + + val securityManager = new SecurityManager(conf) + + assert(securityManager.fileServerSSLOptions.enabled === false) + assert(securityManager.akkaSSLOptions.enabled === false) + assert(securityManager.sslSocketFactory.isDefined === false) + assert(securityManager.hostnameVerifier.isDefined === false) + } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index f57921b76831..d7180516029d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -142,7 +142,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new ShuffledRDD[Int, Int, Int](pairs, @@ -155,7 +155,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) @@ -169,7 +169,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -196,7 +196,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex shuffleSpillCompress <- Set(true, false); shuffleCompress <- Set(true, false) ) { - val conf = new SparkConf() + val myConf = conf.clone() .setAppName("test") .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() - sc = new SparkContext(conf) + sc = new SparkContext(myConf) try { sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() } catch { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 790976a5ac30..272e6af0514e 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -17,9 +17,15 @@ package org.apache.spark +import java.util.concurrent.{TimeUnit, Executors} + +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.{Try, Random} + import org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { @@ -123,6 +129,27 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(conf.get("spark.test.a.b.c") === "a.b.c") } + test("Thread safeness - SPARK-5425") { + import scala.collection.JavaConversions._ + val executor = Executors.newSingleThreadScheduledExecutor() + val sf = executor.scheduleAtFixedRate(new Runnable { + override def run(): Unit = + System.setProperty("spark.5425." + Random.nextInt(), Random.nextInt().toString) + }, 0, 1, TimeUnit.MILLISECONDS) + + try { + val t0 = System.currentTimeMillis() + while ((System.currentTimeMillis() - t0) < 1000) { + val conf = Try(new SparkConf(loadDefaults = true)) + assert(conf.isSuccess === true) + } + } finally { + executor.shutdownNow() + for (key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) + System.getProperties.remove(key) + } + } + test("register kryo classes through registerKryoClasses") { val conf = new SparkConf().set("spark.kryo.registrationRequired", "true") @@ -172,6 +199,51 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro serializer.newInstance().serialize(new StringBuffer()) } + test("deprecated configs") { + val conf = new SparkConf() + val newName = "spark.history.fs.update.interval" + + assert(!conf.contains(newName)) + + conf.set("spark.history.updateInterval", "1") + assert(conf.get(newName) === "1") + + conf.set("spark.history.fs.updateInterval", "2") + assert(conf.get(newName) === "2") + + conf.set("spark.history.fs.update.interval.seconds", "3") + assert(conf.get(newName) === "3") + + conf.set(newName, "4") + assert(conf.get(newName) === "4") + + val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size + assert(count === 4) + + conf.set("spark.yarn.applicationMaster.waitTries", "42") + assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) + } + + test("akka deprecated configs") { + val conf = new SparkConf() + + assert(!conf.contains("spark.rpc.numRetries")) + assert(!conf.contains("spark.rpc.retry.wait")) + assert(!conf.contains("spark.rpc.askTimeout")) + assert(!conf.contains("spark.rpc.lookupTimeout")) + + conf.set("spark.akka.num.retries", "1") + assert(RpcUtils.numRetries(conf) === 1) + + conf.set("spark.akka.retry.wait", "2") + assert(RpcUtils.retryWaitMs(conf) === 2L) + + conf.set("spark.akka.askTimeout", "3") + assert(RpcUtils.askTimeout(conf) === (3 seconds)) + + conf.set("spark.akka.lookupTimeout", "4") + assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 8ae4f243ec1a..bbed8ddc6baf 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -149,7 +149,7 @@ class SparkContextSchedulerCreationSuite } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 8b3c6871a7b3..728558a42478 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -17,9 +17,19 @@ package org.apache.spark +import java.io.File +import java.util.concurrent.TimeUnit + +import com.google.common.base.Charsets._ +import com.google.common.io.Files + import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable +import org.apache.spark.util.Utils + +import scala.concurrent.Await +import scala.concurrent.duration.Duration class SparkContextSuite extends FunSuite with LocalSparkContext { @@ -57,6 +67,26 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { } } + test("Test getOrCreate") { + var sc2: SparkContext = null + SparkContext.clearActiveContext() + val conf = new SparkConf().setAppName("test").setMaster("local") + + sc = SparkContext.getOrCreate(conf) + + assert(sc.getConf.get("spark.app.name").equals("test")) + sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) + assert(sc2.getConf.get("spark.app.name").equals("test")) + assert(sc === sc2) + assert(sc eq sc2) + + // Try creating second context to confirm that it's still possible, if desired + sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") + .set("spark.driver.allowMultipleContexts", "true")) + + sc2.stop() + } + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() @@ -72,4 +102,115 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { val byteArray2 = converter.convert(bytesWritable) assert(byteArray2.length === 0) } + + test("addFile works") { + val dir = Utils.createTempDir() + + val file1 = File.createTempFile("someprefix1", "somesuffix1", dir) + val absolutePath1 = file1.getAbsolutePath + + val file2 = File.createTempFile("someprefix2", "somesuffix2", dir) + val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + "/" + file2.getName + val absolutePath2 = file2.getAbsolutePath + + try { + Files.write("somewords1", file1, UTF_8) + Files.write("somewords2", file2, UTF_8) + val length1 = file1.length() + val length2 = file2.length() + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(file1.getAbsolutePath) + sc.addFile(relativePath) + sc.parallelize(Array(1), 1).map(x => { + val gotten1 = new File(SparkFiles.get(file1.getName)) + val gotten2 = new File(SparkFiles.get(file2.getName)) + if (!gotten1.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath1) + } + if (!gotten2.exists()) { + throw new SparkException("file doesn't exist : " + absolutePath2) + } + + if (length1 != gotten1.length()) { + throw new SparkException( + s"file has different length $length1 than added file ${gotten1.length()} : " + + absolutePath1) + } + if (length2 != gotten2.length()) { + throw new SparkException( + s"file has different length $length2 than added file ${gotten2.length()} : " + + absolutePath2) + } + + if (absolutePath1 == gotten1.getAbsolutePath) { + throw new SparkException("file should have been copied :" + absolutePath1) + } + if (absolutePath2 == gotten2.getAbsolutePath) { + throw new SparkException("file should have been copied : " + absolutePath2) + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive works") { + val pluto = Utils.createTempDir() + val neptune = Utils.createTempDir(pluto.getAbsolutePath) + val saturn = Utils.createTempDir(neptune.getAbsolutePath) + val alien1 = File.createTempFile("alien", "1", neptune) + val alien2 = File.createTempFile("alien", "2", saturn) + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addFile(neptune.getAbsolutePath, true) + sc.parallelize(Array(1), 1).map(x => { + val sep = File.separator + if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { + throw new SparkException("can't access file under root added directory") + } + if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName)) + .exists()) { + throw new SparkException("can't access file in nested directory") + } + if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName)) + .exists()) { + throw new SparkException("file exists that shouldn't") + } + x + }).count() + } finally { + sc.stop() + } + } + + test("addFile recursive can't add directories by default") { + val dir = Utils.createTempDir() + + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + intercept[SparkException] { + sc.addFile(dir.getAbsolutePath) + } + } finally { + sc.stop() + } + } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 41d6ea29d5b0..084eb237d70d 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -82,7 +82,8 @@ class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be ( + Set(firstJobId, secondJobId)) } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 7b866f08a0e9..c63d834f9048 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -23,11 +23,22 @@ import org.scalatest.FunSuite class PythonRDDSuite extends FunSuite { - test("Writing large strings to the worker") { - val input: List[String] = List("a"*100000) - val buffer = new DataOutputStream(new ByteArrayOutputStream) - PythonRDD.writeIteratorToStream(input.iterator, buffer) - } + test("Writing large strings to the worker") { + val input: List[String] = List("a"*100000) + val buffer = new DataOutputStream(new ByteArrayOutputStream) + PythonRDD.writeIteratorToStream(input.iterator, buffer) + } + test("Handle nulls gracefully") { + val buffer = new DataOutputStream(new ByteArrayOutputStream) + // Should not have NPE when write an Iterator with null in it + // The correctness will be tested in Python + PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) + PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) + PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) + PythonRDD.writeIteratorToStream( + Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) + } } - diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala similarity index 55% rename from project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala rename to core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala index 3d43c3529955..f8c39326145e 100644 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -15,25 +15,24 @@ * limitations under the License. */ +package org.apache.spark.api.python -package org.apache.spark.scalastyle +import org.scalatest.FunSuite -import java.util.regex.Pattern +import org.apache.spark.SharedSparkContext -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} +class SerDeUtilSuite extends FunSuite with SharedSparkContext { -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit - -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList + test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + SerDeUtil.pairRDDToPython(emptyRdd, 10) } - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() - + test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + val javaRdd = emptyRdd.toJavaRDD() + val pythonRdd = SerDeUtil.javaToPython(javaRdd) + SerDeUtil.pythonToPairRDD(pythonRdd, false) + } } + diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f..c8fdfa693912 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -33,7 +33,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { val broadcast = rdd.context.broadcast(list) val bid = broadcast.id - def doSomething() = { + def doSomething(): Set[(Int, Boolean)] = { rdd.map { x => val bm = SparkEnv.get.blockManager // Check if broadcast block was fetched @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). @@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index d2dae34be7bf..745f9eeee753 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.Matchers class ClientSuite extends FunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) + ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) // file scheme with authority and path is valid. ClientArguments.isValidJarUrl("file://somehost/path/to/a/jarFile.jar") should be (true) @@ -45,5 +46,4 @@ class ClientSuite extends FunSuite with Matchers { // Invalid syntax. ClientArguments.isValidJarUrl("hdfs:") should be (false) } - } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index aa65f7e8915e..b58d62567afe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} class JsonProtocolSuite extends FunSuite { @@ -68,7 +68,8 @@ class JsonProtocolSuite extends FunSuite { val completedApps = Array[ApplicationInfo]() val activeDrivers = Array(createDriverInfo()) val completedDrivers = Array(createDriverInfo()) - val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) @@ -99,13 +100,13 @@ class JsonProtocolSuite extends FunSuite { appInfo } - def createDriverCommand() = new Command( + def createDriverCommand(): Command = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) - def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, - false, createDriverCommand()) + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) @@ -117,14 +118,15 @@ class JsonProtocolSuite extends FunSuite { } def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), "akka://worker", + new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, + "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) } def createDriverRunner(): DriverRunner = { - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala new file mode 100644 index 000000000000..c93d16f8a158 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.net.URL + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.io.Source + +import org.scalatest.FunSuite + +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} + +class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { + + /** Length of time to wait while draining listener events. */ + private val WAIT_TIMEOUT_MILLIS = 10000 + + test("verify that correct log urls get propagated from workers") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + // Trigger a job so that executors get added + sc.parallelize(1 to 100, 4).map(_.toString).count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + // Browse to each URL to check that it's valid + info.logUrlMap.foreach { case (logType, logUrl) => + val html = Source.fromURL(logUrl).mkString + assert(html.contains(s"$logType log page")) + } + } + } + + test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { + val SPARK_PUBLIC_DNS = "public_dns" + class MySparkConf extends SparkConf(false) { + override def getenv(name: String): String = { + if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(getAll) + } + } + val conf = new MySparkConf().set( + "spark.extraListeners", classOf[SaveExecutorInfo].getName) + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + + // Trigger a job so that executors get added + sc.parallelize(1 to 100, 4).map(_.toString).count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) + listener.addedExecutorInfos.values.foreach { info => + assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { logUrl => + assert(new URL(logUrl).getHost === SPARK_PUBLIC_DNS) + } + } + } +} + +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 065b7534cece..4561e5b8e966 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,25 +21,30 @@ import java.io._ import scala.collection.mutable.ArrayBuffer +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.ByteStreams +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.FunSuite -import org.scalatest.Matchers // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. -class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { +class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { def beforeAll() { System.setProperty("spark.testing", "true") } - val noOpOutputStream = new OutputStream { + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } /** Simple PrintStream that reads data into a buffer */ - class BufferPrintStream extends PrintStream(noOpOutputStream) { + private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() override def println(line: String) { lineBuffer += line @@ -47,7 +52,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } /** Returns true if the script exits and the given search string is printed. */ - def testPrematureExit(input: Array[String], searchString: String) = { + private def testPrematureExit(input: Array[String], searchString: String) = { val printStream = new BufferPrintStream() SparkSubmit.printStream = printStream @@ -138,7 +143,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--executor-memory 5g") @@ -177,7 +182,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -198,6 +203,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } test("handles standalone cluster mode") { + testStandaloneCluster(useRest = true) + } + + test("handles legacy standalone cluster mode") { + testStandaloneCluster(useRest = false) + } + + /** + * Test whether the launch environment is correctly set up in standalone cluster mode. + * @param useRest whether to use the REST submission gateway introduced in Spark 1.3 + */ + private def testStandaloneCluster(useRest: Boolean): Unit = { val clArgs = Seq( "--deploy-mode", "cluster", "--master", "spark://h:p", @@ -209,17 +226,26 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + appArgs.useRest = useRest + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") - childArgsStr should startWith ("--memory 4g --cores 5 --supervise") - childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.Client") - classpath should have size (0) - sysProps should have size (5) + if (useRest) { + childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") + mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + } else { + childArgsStr should startWith ("--supervise --memory 4g --cores 5") + childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" + mainClass should be ("org.apache.spark.deploy.Client") + } + classpath should have size 0 + sysProps should have size 8 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") sysProps.keys should contain ("spark.jars") + sysProps.keys should contain ("spark.driver.memory") + sysProps.keys should contain ("spark.driver.cores") + sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") sysProps("spark.shuffle.spill") should be ("false") } @@ -236,7 +262,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -258,7 +284,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -278,7 +304,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") sysProps("spark.master") should be ("yarn-cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") @@ -290,7 +316,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -305,8 +330,21 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("includes jars passed in through --packages") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", packagesString, "--conf", "spark.ui.enabled=false", - unusedJar.toString) + unusedJar.toString, + "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") runSparkSubmit(args) } @@ -324,7 +362,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -339,7 +377,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -352,7 +390,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) sysProps3("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -364,8 +402,10 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles + val tmpDir = Utils.createTempDir() + // Test jars and files - val f1 = File.createTempFile("test-submit-jars-files", "") + val f1 = File.createTempFile("test-submit-jars-files", "", tmpDir) val writer1 = new PrintWriter(f1) writer1.println("spark.jars " + jars) writer1.println("spark.files " + files) @@ -377,12 +417,12 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) sysProps("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) - val f2 = File.createTempFile("test-submit-files-archives", "") + val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) val writer2 = new PrintWriter(f2) writer2.println("spark.yarn.dist.files " + files) writer2.println("spark.yarn.dist.archives " + archives) @@ -394,12 +434,12 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files - val f3 = File.createTempFile("test-submit-python-files", "") + val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) val writer3 = new PrintWriter(f3) writer3.println("spark.submit.pyFiles " + pyFiles) writer3.close() @@ -409,11 +449,24 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) } + test("user classpath first in driver") { + val systemJar = TestUtils.createJarWithFiles(Map("test.resource" -> "SYSTEM")) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER")) + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.driver.extraClassPath=" + systemJar, + "--conf", "spark.driver.userClassPathFirst=true", + userJar.toString) + runSparkSubmit(args) + } + test("SPARK_CONF_DIR overrides spark-defaults.conf") { forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) @@ -425,20 +478,23 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("2.3g") } } // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - def runSparkSubmit(args: Seq[String]): String = { + private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - Utils.executeAndGetOutput( + val process = Utils.executeCommand( Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } - def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") @@ -463,8 +519,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) - Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString @@ -502,3 +558,15 @@ object SimpleApplicationTest { } } } + +object UserClasspathFirstTest { + def main(args: Array[String]) { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + val contents = new String(bytes, 0, bytes.length, UTF_8) + if (contents != "USER") { + throw new SparkException("Should have read user resource, but instead read: " + contents) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala new file mode 100644 index 000000000000..8bcca926097a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, OutputStream, File} + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.ivy.core.module.descriptor.MDArtifact +import org.apache.ivy.plugins.resolver.IBiblioResolver + +class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { + + private val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + private class BufferPrintStream extends PrintStream(noOpOutputStream) { + var lineBuffer = ArrayBuffer[String]() + override def println(line: String) { + lineBuffer += line + } + } + + override def beforeAll() { + super.beforeAll() + // We don't want to write logs during testing + SparkSubmitUtils.printStream = new BufferPrintStream + } + + test("incorrect maven coordinate throws error") { + val coordinates = Seq("a:b: ", " :a:b", "a: :b", "a:b:", ":a:b", "a::b", "::", "a:b", "a") + for (coordinate <- coordinates) { + intercept[IllegalArgumentException] { + SparkSubmitUtils.extractMavenCoordinates(coordinate) + } + } + } + + test("create repo resolvers") { + val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + // should have central and spark-packages by default + assert(resolver1.getResolvers.size() === 2) + assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") + assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + + val repos = "a/1,b/2,c/3" + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) + assert(resolver2.getResolvers.size() === 5) + val expected = repos.split(",").map(r => s"$r/") + resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => + if (i == 0) { + assert(resolver.getName === "central") + } else if (i == 1) { + assert(resolver.getName === "spark-packages") + } else { + assert(resolver.getName === s"repo-${i - 1}") + assert(resolver.getRoot === expected(i - 2)) + } + } + } + + test("add dependencies works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.10:0.1," + + "com.databricks:spark-avro_2.10:0.1") + + SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default") + assert(md.getDependencies.length === 2) + } + + test("ivy path works correctly") { + val ivyPath = "dummy/ivy" + val md = SparkSubmitUtils.getModuleDescriptor + val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") + var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) + for (i <- 0 until 3) { + val index = jPaths.indexOf(ivyPath) + assert(index >= 0) + jPaths = jPaths.substring(index + ivyPath.length) + } + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + } + + test("search for artifact at other repositories") { + val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", + Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) + assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + + "if package still exists. If it has been removed, replace the example in this test.") + } + + test("dependency not found throws RuntimeException") { + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, true) + } + } + + test("neglects Spark and Spark's dependencies") { + val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") + + val coordinates = + components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + + ",org.apache.spark:spark-core_fake:1.2.0" + + val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) + assert(path === "", "should return empty path") + // Should not exclude the following dependency. Will throw an error, because it doesn't exist, + // but the fact that it is checking means that it wasn't excluded. + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates(coordinates + + ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e..fcae603c7d18 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.deploy.history -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStreamWriter} +import java.net.URI import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.Matchers import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{JsonProtocol, Utils} @@ -37,54 +36,67 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers private var testDir: File = null - private var provider: FsHistoryProvider = null - before { testDir = Utils.createTempDir() - provider = new FsHistoryProvider(new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0")) } after { Utils.deleteRecursively(testDir) } + /** Create a fake log file using the new log format used in Spark 1.3+ */ + private def newLogFile( + appId: String, + inProgress: Boolean, + codec: Option[String] = None): File = { + val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) + val logPath = new URI(logUri).getPath + ip + new File(logPath) + } + test("Parse new and old application logs") { - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. - val logFile1 = new File(testDir, "new1") - writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test"), - SparkListenerApplicationEnd(2L) + val newAppComplete = newLogFile("new1", inProgress = false) + writeFile(newAppComplete, true, None, + SparkListenerApplicationStart("new-app-complete", None, 1L, "test"), + SparkListenerApplicationEnd(5L) ) + // Write a new-style application log. + val newAppCompressedComplete = newLogFile("new1compressed", inProgress = false, Some("lzf")) + writeFile(newAppCompressedComplete, true, None, + SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test"), + SparkListenerApplicationEnd(4L)) + // Write an unfinished app, new-style. - val logFile2 = new File(testDir, "new2" + EventLoggingListener.IN_PROGRESS) - writeFile(logFile2, true, None, - SparkListenerApplicationStart("app2-2", None, 1L, "test") + val newAppIncomplete = newLogFile("new2", inProgress = true) + writeFile(newAppIncomplete, true, None, + SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test") ) // Write an old-style application log. - val oldLog = new File(testDir, "old1") - oldLog.mkdir() - createEmptyFile(new File(oldLog, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app3", None, 2L, "test"), + val oldAppComplete = new File(testDir, "old1") + oldAppComplete.mkdir() + createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-complete", None, 2L, "test"), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldLog, provider.APPLICATION_COMPLETE)) + createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) + + // Check for logs so that we force the older unfinished app to be loaded, to make + // sure unfinished apps are also sorted correctly. + provider.checkForLogs() // Write an unfinished app, old-style. - val oldLog2 = new File(testDir, "old2") - oldLog2.mkdir() - createEmptyFile(new File(oldLog2, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldLog2, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("app4", None, 2L, "test") + val oldAppIncomplete = new File(testDir, "old2") + oldAppIncomplete.mkdir() + createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, + SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test") ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -93,17 +105,19 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers val list = provider.getListing().toSeq list should not be (null) - list.size should be (4) - list.count(e => e.completed) should be (2) - - list(0) should be (ApplicationHistoryInfo(oldLog.getName(), "app3", 2L, 3L, - oldLog.lastModified(), "test", true)) - list(1) should be (ApplicationHistoryInfo(logFile1.getName(), "app1-1", 1L, 2L, - logFile1.lastModified(), "test", true)) - list(2) should be (ApplicationHistoryInfo(oldLog2.getName(), "app4", 2L, -1L, - oldLog2.lastModified(), "test", false)) - list(3) should be (ApplicationHistoryInfo(logFile2.getName(), "app2-2", 1L, -1L, - logFile2.lastModified(), "test", false)) + list.size should be (5) + list.count(_.completed) should be (3) + + list(0) should be (ApplicationHistoryInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, + newAppComplete.lastModified(), "test", true)) + list(1) should be (ApplicationHistoryInfo(newAppCompressedComplete.getName(), + "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (ApplicationHistoryInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + oldAppComplete.lastModified(), "test", true)) + list(3) should be (ApplicationHistoryInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (ApplicationHistoryInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, + -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. list.foreach { case info => @@ -113,6 +127,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("Parse legacy logs with compression codec set") { + val provider = new FsHistoryProvider(createTestConf()) val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), (classOf[SnappyCompressionCodec].getName(), true), ("invalid.codec", false)) @@ -130,7 +145,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers val logPath = new Path(logDir.getAbsolutePath()) try { - val (logInput, sparkVersion) = provider.openLegacyEventLog(logPath) + val logInput = provider.openLegacyEventLog(logPath) try { Source.fromInputStream(logInput).getLines().toSeq.size should be (2) } finally { @@ -144,22 +159,19 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers } test("SPARK-3697: ignore directories that cannot be read.") { - val logFile1 = new File(testDir, "new1") + val logFile1 = newLogFile("new1", inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", None, 1L, "test"), SparkListenerApplicationEnd(2L) ) - val logFile2 = new File(testDir, "new2") + val logFile2 = newLogFile("new2", inProgress = false) writeFile(logFile2, true, None, SparkListenerApplicationStart("app1-2", None, 1L, "test"), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) - val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) - .set("spark.history.fs.updateInterval", "0") - val provider = new FsHistoryProvider(conf) + val provider = new FsHistoryProvider(createTestConf()) provider.checkForLogs() val list = provider.getListing().toSeq @@ -167,16 +179,51 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val provider = new FsHistoryProvider(createTestConf()) + + val logFile1 = newLogFile("app1", inProgress = true) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(newLogFile("app1", inProgress = false)) + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + + test("SPARK-5582: empty log directory") { + val provider = new FsHistoryProvider(createTestConf()) + + val logFile1 = newLogFile("app1", inProgress = true) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L)) + + val oldLog = new File(testDir, "old1") + oldLog.mkdir() + + provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { - val out = - if (isNewFormat) { - EventLoggingListener.initEventLog(new FileOutputStream(file), codec) - } else { - val fileStream = new FileOutputStream(file) - codec.map(_.compressedOutputStream(fileStream)).getOrElse(fileStream) - } - val writer = new OutputStreamWriter(out, "UTF-8") + val fstream = new FileOutputStream(file) + val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) + val bstream = new BufferedOutputStream(cstream) + if (isNewFormat) { + EventLoggingListener.initEventLog(new FileOutputStream(file)) + } + val writer = new OutputStreamWriter(bstream, "UTF-8") try { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) } finally { @@ -188,4 +235,8 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers new FileOutputStream(file).close() } + private def createTestConf(): SparkConf = { + new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala new file mode 100644 index 000000000000..20de46fdab90 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.mockito.Mockito.{when} +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.ui.SparkUI + +class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar { + + test("generate history page with relative links") { + val historyServer = mock[HistoryServer] + val request = mock[HttpServletRequest] + val ui = mock[SparkUI] + val link = "/history/app1" + val info = new ApplicationHistoryInfo("app1", "app1", 0, 2, 1, "xxx", true) + when(historyServer.getApplicationList()).thenReturn(Seq(info)) + when(ui.basePath).thenReturn(link) + when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) + val page = new HistoryPage(historyServer) + + // when + val response = page.render(request) + + // then + val links = response \\ "a" + val justHrefs = for { + l <- links + attrs <- l.attribute("href") + } yield (attrs.toString) + justHrefs should contain(link) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 3d2335f9b363..34c74d87f0a6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -20,30 +20,46 @@ package org.apache.spark.deploy.master import akka.actor.Address import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SSLOptions, SparkConf, SparkException} class MasterSuite extends FunSuite { test("toAkkaUrl") { - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234") + val conf = new SparkConf(loadDefaults = false) + val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) } + test("toAkkaUrl with SSL") { + val conf = new SparkConf(loadDefaults = false) + val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") + assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) + } + test("toAkkaUrl: a typo url") { + val conf = new SparkConf(loadDefaults = false) val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234") + Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp") } assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) } test("toAkkaAddress") { - val address = Master.toAkkaAddress("spark://1.2.3.4:1234") + val conf = new SparkConf(loadDefaults = false) + val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) } + test("toAkkaAddress with SSL") { + val conf = new SparkConf(loadDefaults = false) + val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") + assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) + } + test("toAkkaAddress: a typo url") { + val conf = new SparkConf(loadDefaults = false) val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234") + Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") } assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala new file mode 100644 index 000000000000..8e0997663638 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} +import javax.servlet.http.HttpServletResponse + +import scala.collection.mutable + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import com.google.common.base.Charsets +import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.master.DriverState._ + +/** + * Tests for the REST application submission protocol used in standalone cluster mode. + */ +class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { + private val client = new StandaloneRestClient + private var actorSystem: Option[ActorSystem] = None + private var server: Option[StandaloneRestServer] = None + + override def afterEach() { + actorSystem.foreach(_.shutdown()) + server.foreach(_.stop()) + } + + test("construct submit request") { + val appArgs = Array("one", "two", "three") + val sparkProperties = Map("spark.app.name" -> "pi") + val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") + val request = client.constructSubmitRequest( + "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) + assert(request.action === Utils.getFormattedClassName(request)) + assert(request.clientSparkVersion === SPARK_VERSION) + assert(request.appResource === "my-app-resource") + assert(request.mainClass === "my-main-class") + assert(request.appArgs === appArgs) + assert(request.sparkProperties === sparkProperties) + assert(request.environmentVariables === environmentVariables) + } + + test("create submission") { + val submittedDriverId = "my-driver-id" + val submitMessage = "your driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val appArgs = Array("one", "two", "four") + val request = constructSubmitRequest(masterUrl, appArgs) + assert(request.appArgs === appArgs) + assert(request.sparkProperties("spark.master") === masterUrl) + val response = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("create submission from main method") { + val submittedDriverId = "your-driver-id" + val submitMessage = "my driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.master", masterUrl) + conf.set("spark.app.name", "dreamer") + val appArgs = Array("one", "two", "six") + // main method calls this + val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("kill submission") { + val submissionId = "my-lyft-driver" + val killMessage = "your driver is killed" + val masterUrl = startDummyServer(killMessage = killMessage) + val response = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response) + assert(killResponse.action === Utils.getFormattedClassName(killResponse)) + assert(killResponse.serverSparkVersion === SPARK_VERSION) + assert(killResponse.message === killMessage) + assert(killResponse.submissionId === submissionId) + assert(killResponse.success) + } + + test("request submission status") { + val submissionId = "my-uber-driver" + val submissionState = KILLED + val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") + val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) + val response = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response) + assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) + assert(statusResponse.serverSparkVersion === SPARK_VERSION) + assert(statusResponse.message.contains(submissionException.getMessage)) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === submissionState.toString) + assert(statusResponse.success) + } + + test("create then kill") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // kill submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response2) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId) + } + + test("create then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // request status of submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response2) + assert(statusResponse.success) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === RUNNING.toString) + } + + test("create then kill then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val response2 = client.createSubmission(masterUrl, request) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(submitResponse1.success) + assert(submitResponse2.success) + assert(submitResponse1.submissionId != null) + assert(submitResponse2.submissionId != null) + val submissionId1 = submitResponse1.submissionId + val submissionId2 = submitResponse2.submissionId + // kill only submission 1, but not submission 2 + val response3 = client.killSubmission(masterUrl, submissionId1) + val killResponse = getKillResponse(response3) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId1) + // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still + val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) + val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(statusResponse1.submissionId === submissionId1) + assert(statusResponse2.submissionId === submissionId2) + assert(statusResponse1.driverState === KILLED.toString) + assert(statusResponse2.driverState === RUNNING.toString) + } + + test("kill or request status before create") { + val masterUrl = startSmartServer() + val doesNotExist = "does-not-exist" + // kill a non-existent submission + val response1 = client.killSubmission(masterUrl, doesNotExist) + val killResponse = getKillResponse(response1) + assert(!killResponse.success) + assert(killResponse.submissionId === doesNotExist) + // request status for a non-existent submission + val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val statusResponse = getStatusResponse(response2) + assert(!statusResponse.success) + assert(statusResponse.submissionId === doesNotExist) + } + + /* ---------------------------------------- * + | Aberrant client / server behavior | + * ---------------------------------------- */ + + test("good request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val json = constructSubmitRequest(masterUrl).toJson + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", json) + val (response2, code2) = sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST") + val (response3, code3) = sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST") + val (response4, code4) = sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET") + // these should all succeed and the responses should be of the correct types + getSubmitResponse(response1) + val killResponse1 = getKillResponse(response2) + val killResponse2 = getKillResponse(response3) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(killResponse1.submissionId === "anything") + assert(killResponse2.submissionId === "any") + assert(statusResponse1.submissionId === "anything") + assert(statusResponse2.submissionId === "any") + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + assert(code3 === HttpServletResponse.SC_OK) + assert(code4 === HttpServletResponse.SC_OK) + assert(code5 === HttpServletResponse.SC_OK) + } + + test("good request paths, bad requests") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val goodJson = constructSubmitRequest(masterUrl).toJson + val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON + val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON + val notJson = "\"hello, world\"" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST") // missing JSON + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson1) + val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson2) + val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, "POST") // missing ID + val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", "POST") + val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, "GET") // missing ID + val (response7, code7) = sendHttpRequestWithResponse(s"$statusRequestPath/", "GET") + val (response8, code8) = sendHttpRequestWithResponse(submitRequestPath, "POST", notJson) + // these should all fail as error responses + getErrorResponse(response1) + getErrorResponse(response2) + getErrorResponse(response3) + getErrorResponse(response4) + getErrorResponse(response5) + getErrorResponse(response6) + getErrorResponse(response7) + getErrorResponse(response8) + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === HttpServletResponse.SC_BAD_REQUEST) + } + + test("bad request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") + val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") + val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") + val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET") + val (response6, code6) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET") + val (response7, code7) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET") + val (response8, code8) = sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET") + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + // all responses should be error responses + val errorResponse1 = getErrorResponse(response1) + val errorResponse2 = getErrorResponse(response2) + val errorResponse3 = getErrorResponse(response3) + val errorResponse4 = getErrorResponse(response4) + val errorResponse5 = getErrorResponse(response5) + val errorResponse6 = getErrorResponse(response6) + val errorResponse7 = getErrorResponse(response7) + val errorResponse8 = getErrorResponse(response8) + // only the incompatible version response should have server protocol version set + assert(errorResponse1.highestProtocolVersion === null) + assert(errorResponse2.highestProtocolVersion === null) + assert(errorResponse3.highestProtocolVersion === null) + assert(errorResponse4.highestProtocolVersion === null) + assert(errorResponse5.highestProtocolVersion === null) + assert(errorResponse6.highestProtocolVersion === null) + assert(errorResponse7.highestProtocolVersion === null) + assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + } + + test("server returns unknown fields") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val oldJson = constructSubmitRequest(masterUrl).toJson + val oldFields = parse(oldJson).asInstanceOf[JObject].obj + val newFields = oldFields ++ Seq( + JField("tomato", JString("not-a-fruit")), + JField("potato", JString("not-po-tah-to")) + ) + val newJson = pretty(render(JObject(newFields))) + // send two requests, one with the unknown fields and the other without + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", oldJson) + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", newJson) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + // only the response to the modified request should have unknown fields set + assert(submitResponse1.unknownFields === null) + assert(submitResponse2.unknownFields === Array("tomato", "potato")) + } + + test("client handles faulty server") { + val masterUrl = startFaultyServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" + val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" + val json = constructSubmitRequest(masterUrl).toJson + // server returns malformed response unwittingly + // client should throw an appropriate exception to indicate server failure + val conn1 = sendHttpRequest(submitRequestPath, "POST", json) + intercept[SubmitRestProtocolException] { client.readResponse(conn1) } + // server attempts to send invalid response, but fails internally on validation + // client should receive an error response as server is able to recover + val conn2 = sendHttpRequest(killRequestPath, "POST") + val response2 = client.readResponse(conn2) + getErrorResponse(response2) + assert(conn2.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + // server explodes internally beyond recovery + // client should throw an appropriate exception to indicate server failure + val conn3 = sendHttpRequest(statusRequestPath, "GET") + intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // empty response + assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + } + + /* --------------------- * + | Helper methods | + * --------------------- */ + + /** Start a dummy server that responds to requests using the specified parameters. */ + private def startDummyServer( + submitId: String = "fake-driver-id", + submitMessage: String = "driver is submitted", + killMessage: String = "driver is killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None): String = { + startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + } + + /** Start a smarter dummy server that keeps track of submitted driver states. */ + private def startSmartServer(): String = { + startServer(new SmarterMaster) + } + + /** Start a dummy server that is faulty in many ways... */ + private def startFaultyServer(): String = { + startServer(new DummyMaster, faulty = true) + } + + /** + * Start a [[StandaloneRestServer]] that communicates with the given actor. + * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. + * Return the master URL that corresponds to the address of this server. + */ + private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + val name = "test-standalone-rest-protocol" + val conf = new SparkConf + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _server = + if (faulty) { + new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } else { + new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } + val port = _server.start() + // set these to clean them up after every test + actorSystem = Some(_actorSystem) + server = Some(_server) + s"spark://$localhost:$port" + } + + /** Create a submit request with real parameters using Spark submit. */ + private def constructSubmitRequest( + masterUrl: String, + appArgs: Array[String] = Array.empty): CreateSubmissionRequest = { + val mainClass = "main-class-not-used" + val mainJar = "dummy-jar-not-used.jar" + val commandLineArgs = Array( + "--deploy-mode", "cluster", + "--master", masterUrl, + "--name", mainClass, + "--class", mainClass, + mainJar) ++ appArgs + val args = new SparkSubmitArguments(commandLineArgs) + val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + client.constructSubmitRequest( + mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + } + + /** Return the response as a submit response, or fail with error otherwise. */ + private def getSubmitResponse(response: SubmitRestProtocolResponse): CreateSubmissionResponse = { + response match { + case s: CreateSubmissionResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected submit response. Actual: ${r.toJson}") + } + } + + /** Return the response as a kill response, or fail with error otherwise. */ + private def getKillResponse(response: SubmitRestProtocolResponse): KillSubmissionResponse = { + response match { + case k: KillSubmissionResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected kill response. Actual: ${r.toJson}") + } + } + + /** Return the response as a status response, or fail with error otherwise. */ + private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = { + response match { + case s: SubmissionStatusResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected status response. Actual: ${r.toJson}") + } + } + + /** Return the response as an error response, or fail if the response was not an error. */ + private def getErrorResponse(response: SubmitRestProtocolResponse): ErrorResponse = { + response match { + case e: ErrorResponse => e + case r => fail(s"Expected error response. Actual: ${r.toJson}") + } + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. + * Return the connection object. + */ + private def sendHttpRequest( + url: String, + method: String, + body: String = ""): HttpURLConnection = { + val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod(method) + if (body.nonEmpty) { + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(body.getBytes(Charsets.UTF_8)) + out.close() + } + conn + } + + /** + * Send an HTTP request to the given URL using the method and the body specified. + * Return a 2-tuple of the response message from the server and the response code. + */ + private def sendHttpRequestWithResponse( + url: String, + method: String, + body: String = ""): (SubmitRestProtocolResponse, Int) = { + val conn = sendHttpRequest(url, method, body) + (client.readResponse(conn), conn.getResponseCode) + } +} + +/** + * A mock standalone Master that responds with dummy messages. + * In all responses, the success parameter is always true. + */ +private class DummyMaster( + submitId: String = "fake-driver-id", + submitMessage: String = "submitted", + killMessage: String = "killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None) + extends Actor { + + override def receive: PartialFunction[Any, Unit] = { + case RequestSubmitDriver(driverDesc) => + sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + case RequestKillDriver(driverId) => + sender ! KillDriverResponse(driverId, success = true, killMessage) + case RequestDriverStatus(driverId) => + sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + } +} + +/** + * A mock standalone Master that keeps track of drivers that have been submitted. + * + * If a driver is submitted, its state is immediately set to RUNNING. + * If an existing driver is killed, its state is immediately set to KILLED. + * If an existing driver's status is requested, its state is returned in the response. + * Submits are always successful while kills and status requests are successful only + * if the driver was submitted in the past. + */ +private class SmarterMaster extends Actor { + private var counter: Int = 0 + private val submittedDrivers = new mutable.HashMap[String, DriverState] + + override def receive: PartialFunction[Any, Unit] = { + case RequestSubmitDriver(driverDesc) => + val driverId = s"driver-$counter" + submittedDrivers(driverId) = RUNNING + counter += 1 + sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted") + + case RequestKillDriver(driverId) => + val success = submittedDrivers.contains(driverId) + if (success) { + submittedDrivers(driverId) = KILLED + } + sender ! KillDriverResponse(driverId, success, "killed") + + case RequestDriverStatus(driverId) => + val found = submittedDrivers.contains(driverId) + val state = submittedDrivers.get(driverId) + sender ! DriverStatusResponse(found, state, None, None, None) + } +} + +/** + * A [[StandaloneRestServer]] that is faulty in many ways. + * + * When handling a submit request, the server returns a malformed JSON. + * When handling a kill request, the server returns an invalid JSON. + * When handling a status request, the server throws an internal exception. + * The purpose of this class is to test that client handles these cases gracefully. + */ +private class FaultyStandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + + protected override val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new MalformedSubmitServlet, + s"$baseContext/kill/*" -> new InvalidKillServlet, + s"$baseContext/status/*" -> new ExplodingStatusServlet, + "/*" -> new ErrorServlet + ) + + /** A faulty servlet that produces malformed responses. */ + class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + protected override def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val badJson = responseMessage.toJson.drop(10).dropRight(20) + responseServlet.getWriter.write(badJson) + } + } + + /** A faulty servlet that produces invalid responses. */ + class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = super.handleKill(submissionId) + k.submissionId = null + k + } + } + + /** A faulty status servlet that explodes. */ + class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + private def explode: Int = 1 / 0 + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val s = super.handleStatus(submissionId) + s.workerId = explode.toString + s + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala new file mode 100644 index 000000000000..61071ee17256 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.lang.Boolean +import java.lang.Integer + +import org.json4s.jackson.JsonMethods._ +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf + +/** + * Tests for the REST application submission protocol. + */ +class SubmitRestProtocolSuite extends FunSuite { + + test("validate") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.validate() } // missing everything + request.clientSparkVersion = "1.2.3" + intercept[SubmitRestProtocolException] { request.validate() } // missing name and age + request.name = "something" + intercept[SubmitRestProtocolException] { request.validate() } // missing only age + request.age = 2 + intercept[SubmitRestProtocolException] { request.validate() } // age too low + request.age = 10 + request.validate() // everything is set properly + request.clientSparkVersion = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version + request.clientSparkVersion = "1.2.3" + request.name = null + intercept[SubmitRestProtocolException] { request.validate() } // missing only name + request.message = "not-setting-name" + intercept[SubmitRestProtocolException] { request.validate() } // still missing name + } + + test("request to and from JSON") { + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.toJson } // implicit validation + request.clientSparkVersion = "1.2.3" + request.active = true + request.age = 25 + request.name = "jung" + val json = request.toJson + assertJsonEquals(json, dummyRequestJson) + val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.active) + assert(newRequest.age === 25) + assert(newRequest.name === "jung") + assert(newRequest.message === null) + } + + test("response to and from JSON") { + val response = new DummyResponse + response.serverSparkVersion = "3.3.4" + response.success = true + val json = response.toJson + assertJsonEquals(json, dummyResponseJson) + val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.success) + assert(newResponse.message === null) + } + + test("CreateSubmissionRequest") { + val message = new CreateSubmissionRequest + intercept[SubmitRestProtocolException] { message.validate() } + message.clientSparkVersion = "1.2.3" + message.appResource = "honey-walnut-cherry.jar" + message.mainClass = "org.apache.spark.examples.SparkPie" + val conf = new SparkConf(false) + conf.set("spark.app.name", "SparkPie") + message.sparkProperties = conf.getAll.toMap + message.validate() + // optional fields + conf.set("spark.jars", "mayonnaise.jar,ketchup.jar") + conf.set("spark.files", "fireball.png") + conf.set("spark.driver.memory", "512m") + conf.set("spark.driver.cores", "180") + conf.set("spark.driver.extraJavaOptions", " -Dslices=5 -Dcolor=mostly_red") + conf.set("spark.driver.extraClassPath", "food-coloring.jar") + conf.set("spark.driver.extraLibraryPath", "pickle.jar") + conf.set("spark.driver.supervise", "false") + conf.set("spark.executor.memory", "256m") + conf.set("spark.cores.max", "10000") + message.sparkProperties = conf.getAll.toMap + message.appArgs = Array("two slices", "a hint of cinnamon") + message.environmentVariables = Map("PATH" -> "/dev/null") + message.validate() + // bad fields + var badConf = conf.clone().set("spark.driver.cores", "one hundred feet") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = conf.clone().set("spark.driver.supervise", "nope, never") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + badConf = conf.clone().set("spark.cores.max", "two men") + message.sparkProperties = badConf.getAll.toMap + intercept[SubmitRestProtocolException] { message.validate() } + message.sparkProperties = conf.getAll.toMap + // test JSON + val json = message.toJson + assertJsonEquals(json, submitDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionRequest]) + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.appResource === "honey-walnut-cherry.jar") + assert(newMessage.mainClass === "org.apache.spark.examples.SparkPie") + assert(newMessage.sparkProperties("spark.app.name") === "SparkPie") + assert(newMessage.sparkProperties("spark.jars") === "mayonnaise.jar,ketchup.jar") + assert(newMessage.sparkProperties("spark.files") === "fireball.png") + assert(newMessage.sparkProperties("spark.driver.memory") === "512m") + assert(newMessage.sparkProperties("spark.driver.cores") === "180") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === + " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") + assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") + assert(newMessage.sparkProperties("spark.driver.supervise") === "false") + assert(newMessage.sparkProperties("spark.executor.memory") === "256m") + assert(newMessage.sparkProperties("spark.cores.max") === "10000") + assert(newMessage.appArgs === message.appArgs) + assert(newMessage.sparkProperties === message.sparkProperties) + assert(newMessage.environmentVariables === message.environmentVariables) + } + + test("CreateSubmissionResponse") { + val message = new CreateSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, submitDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("KillSubmissionResponse") { + val message = new KillSubmissionResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, killDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.success) + } + + test("SubmissionStatusResponse") { + val message = new SubmissionStatusResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.submissionId = "driver_123" + message.success = true + message.validate() + // optional fields + message.driverState = "RUNNING" + message.workerId = "worker_123" + message.workerHostPort = "1.2.3.4:7780" + // test JSON + val json = message.toJson + assertJsonEquals(json, driverStatusResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.submissionId === "driver_123") + assert(newMessage.driverState === "RUNNING") + assert(newMessage.success) + assert(newMessage.workerId === "worker_123") + assert(newMessage.workerHostPort === "1.2.3.4:7780") + } + + test("ErrorResponse") { + val message = new ErrorResponse + intercept[SubmitRestProtocolException] { message.validate() } + message.serverSparkVersion = "1.2.3" + message.message = "Field not found in submit request: X" + message.validate() + // test JSON + val json = message.toJson + assertJsonEquals(json, errorJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.message === "Field not found in submit request: X") + } + + private val dummyRequestJson = + """ + |{ + | "action" : "DummyRequest", + | "active" : true, + | "age" : 25, + | "clientSparkVersion" : "1.2.3", + | "name" : "jung" + |} + """.stripMargin + + private val dummyResponseJson = + """ + |{ + | "action" : "DummyResponse", + | "serverSparkVersion" : "3.3.4", + | "success": true + |} + """.stripMargin + + private val submitDriverRequestJson = + """ + |{ + | "action" : "CreateSubmissionRequest", + | "appArgs" : [ "two slices", "a hint of cinnamon" ], + | "appResource" : "honey-walnut-cherry.jar", + | "clientSparkVersion" : "1.2.3", + | "environmentVariables" : { + | "PATH" : "/dev/null" + | }, + | "mainClass" : "org.apache.spark.examples.SparkPie", + | "sparkProperties" : { + | "spark.driver.extraLibraryPath" : "pickle.jar", + | "spark.jars" : "mayonnaise.jar,ketchup.jar", + | "spark.driver.supervise" : "false", + | "spark.app.name" : "SparkPie", + | "spark.cores.max" : "10000", + | "spark.driver.memory" : "512m", + | "spark.files" : "fireball.png", + | "spark.driver.cores" : "180", + | "spark.driver.extraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", + | "spark.executor.memory" : "256m", + | "spark.driver.extraClassPath" : "food-coloring.jar" + | } + |} + """.stripMargin + + private val submitDriverResponseJson = + """ + |{ + | "action" : "CreateSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val killDriverResponseJson = + """ + |{ + | "action" : "KillSubmissionResponse", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true + |} + """.stripMargin + + private val driverStatusResponseJson = + """ + |{ + | "action" : "SubmissionStatusResponse", + | "driverState" : "RUNNING", + | "serverSparkVersion" : "1.2.3", + | "submissionId" : "driver_123", + | "success" : true, + | "workerHostPort" : "1.2.3.4:7780", + | "workerId" : "worker_123" + |} + """.stripMargin + + private val errorJson = + """ + |{ + | "action" : "ErrorResponse", + | "message" : "Field not found in submit request: X", + | "serverSparkVersion" : "1.2.3" + |} + """.stripMargin + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } +} + +private class DummyResponse extends SubmitRestProtocolResponse +private class DummyRequest extends SubmitRestProtocolRequest { + var active: Boolean = null + var age: Integer = null + var name: String = null + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") + assert(age > 5, "Not old enough!") + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala similarity index 94% rename from core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala rename to core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 7915ee75d877..1c27d83cf876 100644 --- a/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.deploy +package org.apache.spark.deploy.worker -import org.apache.spark.deploy.worker.CommandUtils +import org.apache.spark.deploy.Command import org.apache.spark.util.Utils - import org.scalatest.{FunSuite, Matchers} class CommandUtilsSuite extends FunSuite with Matchers { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index b6f4411e0587..2159fd8c16c6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -25,15 +25,17 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.util.Clock class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - driverDescription, null, "akka://1.2.3.4/worker/") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + driverDescription, null, "akka://1.2.3.4/worker/", new SecurityManager(conf)) } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { @@ -129,7 +131,7 @@ class DriverRunnerTest extends FunSuite { .thenReturn(-1) // fail 3 .thenReturn(-1) // fail 4 .thenReturn(0) // success - when(clock.currentTimeMillis()) + when(clock.getTimeMillis()) .thenReturn(0).thenReturn(1000) // fail 1 (short) .thenReturn(1000).thenReturn(2000) // fail 2 (short) .thenReturn(2000).thenReturn(10000) // fail 3 (long) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6f233d7cf97a..a8b9df227c99 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -32,10 +32,11 @@ class ExecutorRunnerTest extends FunSuite { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", - new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, + "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) + val builder = CommandUtils.buildProcessBuilder( + appDesc.command, 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd..7cc210428146 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -37,13 +37,13 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "50000" else super.getenv(name) } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() @@ -56,13 +56,13 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "5G" else super.getenv(name) } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala new file mode 100644 index 000000000000..450fba21f4b5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.Command + +import org.scalatest.{Matchers, FunSuite} + +class WorkerSuite extends FunSuite with Matchers { + + def cmd(javaOpts: String*): Command = { + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + } + def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) + + test("test isUseLocalNodeSSLConfig") { + Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=true")) shouldBe true + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=false")) shouldBe false + Worker.isUseLocalNodeSSLConfig(cmd("-Dspark.ssl.useNodeLocalConf=")) shouldBe false + } + + test("test maybeUpdateSSLSettings") { + Worker.maybeUpdateSSLSettings( + cmd("-Dasdf=dfgh", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsInOrderAs Seq( + "-Dasdf=dfgh", "-Dspark.ssl.opt1=x") + + Worker.maybeUpdateSSLSettings( + cmd("-Dspark.ssl.useNodeLocalConf=false", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsInOrderAs Seq( + "-Dspark.ssl.useNodeLocalConf=false", "-Dspark.ssl.opt1=x") + + Worker.maybeUpdateSSLSettings( + cmd("-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=x"), + conf("spark.ssl.opt1" -> "y", "spark.ssl.opt2" -> "z")) + .javaOpts should contain theSameElementsAs Seq( + "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") + + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a..6a6f29dd613c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,38 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + assert(workerWatcher.isShutDown) + rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + assert(!workerWatcher.isShutDown) + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala new file mode 100644 index 000000000000..326e203afe13 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.scalatest.FunSuite + +class TaskMetricsSuite extends FunSuite { + test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { + val taskMetrics = new TaskMetrics() + taskMetrics.updateShuffleReadMetrics() + assert(taskMetrics.shuffleReadMetrics.isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 98b0a16ce88b..2e58c159a2ed 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -42,7 +42,15 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var factory: CompressionCodecFactory = _ override def beforeAll() { - sc = new SparkContext("local", "test") + // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which + // can cause Filesystem.get(Configuration) to return a cached instance created with a different + // configuration than the one passed to get() (see HADOOP-8490 for more details). This caused + // hard-to-reproduce test failures, since any suites that were run after this one would inherit + // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, + // we disable FileSystem caching in this suite. + val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") + + sc = new SparkContext("local", "test", conf) // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80c..190b08d950a0 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,35 +21,46 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.commons.lang.math.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} +import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, + InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, + RecordReader => OldRecordReader, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, + RecordReader => NewRecordReader} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext + with BeforeAndAfter { @transient var tmpDir: File = _ @transient var tmpFile: File = _ @transient var tmpFilePath: String = _ + @transient val numRecords: Int = 100000 + @transient val numBuckets: Int = 10 - override def beforeAll() { - super.beforeAll() - + before { tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) - for (x <- 1 to 1000000) { - pw.println("s") + for (x <- 1 to numRecords) { + pw.println(RandomUtils.nextInt(numBuckets)) } pw.close() @@ -57,8 +68,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { tmpFilePath = "file://" + tmpFile.getAbsolutePath } - override def afterAll() { - super.afterAll() + after { Utils.deleteRecursively(tmpDir) } @@ -146,6 +156,101 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(bytesRead >= tmpFile.length()) } + test("input metrics on records read - simple") { + val records = runAndReturnRecordsRead { + sc.textFile(tmpFilePath, 4).count() + } + assert(records == numRecords) + } + + test("input metrics on records read - more stages") { + val records = runAndReturnRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key.length, 1)) + .reduceByKey(_ + _) + .count() + } + assert(records == numRecords) + } + + test("input metrics on records - New Hadoop API") { + val records = runAndReturnRecordsRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).count() + } + assert(records == numRecords) + } + + test("input metrics on recordsd read with cache") { + // prime the cache manager + val rdd = sc.textFile(tmpFilePath, 4).cache() + rdd.collect() + + val records = runAndReturnRecordsRead { + rdd.count() + } + + assert(records == numRecords) + } + + test("shuffle records read metrics") { + val recordsRead = runAndReturnShuffleRecordsRead { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsRead == numRecords) + } + + test("shuffle records written metrics") { + val recordsWritten = runAndReturnShuffleRecordsWritten { + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + assert(recordsWritten == numRecords) + } + + /** + * Tests the metrics from end to end. + * 1) reading a hadoop file + * 2) shuffle and writing to a hadoop file. + * 3) writing to hadoop file. + */ + test("input read/write and shuffle read/write metrics all line up") { + var inputRead = 0L + var outputWritten = 0L + var shuffleRead = 0L + var shuffleWritten = 0L + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val metrics = taskEnd.taskMetrics + metrics.inputMetrics.foreach(inputRead += _.recordsRead) + metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) + metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) + metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.shuffleRecordsWritten) + } + }) + + val tmpFile = new File(tmpDir, getClass.getSimpleName) + + sc.textFile(tmpFilePath, 4) + .map(key => (key, 1)) + .reduceByKey(_ + _) + .saveAsTextFile("file://" + tmpFile.getAbsolutePath) + + sc.listenerBus.waitUntilEmpty(500) + assert(inputRead == numRecords) + + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + assert(outputWritten == numBuckets) + } + assert(shuffleRead == shuffleWritten) + } + test("input metrics with interleaved reads") { val numPartitions = 2 val cartVector = 0 to 9 @@ -184,25 +289,73 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize)) } - private def runAndReturnBytesRead(job : => Unit): Long = { - val taskBytesRead = new ArrayBuffer[Long]() + private def runAndReturnBytesRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.bytesRead)) + } + + private def runAndReturnRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.recordsRead)) + } + + private def runAndReturnRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) + } + + private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) + } + + private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { + runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) + } + + private def runAndReturnMetrics(job: => Unit, + collector: (SparkListenerTaskEnd) => Option[Long]): Long = { + val taskMetrics = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + collector(taskEnd).foreach(taskMetrics += _) } }) job sc.listenerBus.waitUntilEmpty(500) - taskBytesRead.sum + taskMetrics.sum + } + + test("output metrics on records written") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).saveAsTextFile(filePath) + } + assert(records == numRecords) + } + } + + test("output metrics on records written - new Hadoop API") { + // Only supported on newer Hadoop + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = "file://" + file.getAbsolutePath + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](filePath) + } + assert(records == numRecords) + } } test("output metrics when writing text file") { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +378,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } +} + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 1a9ce8c607dc..100ac77dec1f 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -27,7 +27,7 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { } test("MetricsConfig with default properties") { - val conf = new MetricsConfig(Option("dummy-file")) + val conf = new MetricsConfig(None) conf.initialize() assert(conf.properties.size() === 4) @@ -35,7 +35,8 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { val property = conf.getInstance("random") assert(property.size() === 2) - assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(property.getProperty("sink.servlet.path") === "/metrics/json") } @@ -47,16 +48,20 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(masterProp.size() === 5) assert(masterProp.getProperty("sink.console.period") === "20") assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") val workerProp = conf.getInstance("worker") assert(workerProp.size() === 5) assert(workerProp.getProperty("sink.console.period") === "10") assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 716f875d30b8..02424c59d683 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -260,8 +260,8 @@ class ConnectionManagerSuite extends FunSuite { test("sendMessageReliably timeout") { val clientConf = new SparkConf clientConf.set("spark.authenticate", "false") - val ackTimeout = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + val ackTimeoutS = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") val clientSecurityManager = new SecurityManager(clientConf) val manager = new ConnectionManager(0, clientConf, clientSecurityManager) @@ -272,7 +272,7 @@ class ConnectionManagerSuite extends FunSuite { val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeout * 3 * 1000) + Thread.sleep(ackTimeoutS * 3 * 1000) None }) @@ -287,7 +287,7 @@ class ConnectionManagerSuite extends FunSuite { // Otherwise TimeoutExcepton is thrown from Await.result. // We expect TimeoutException is not thrown. intercept[IOException] { - Await.result(future, (ackTimeout * 2) second) + Await.result(future, (ackTimeoutS * 2) second) } manager.stop() diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index de306533752c..01039b9449da 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -22,6 +22,12 @@ import org.scalatest.FunSuite import org.apache.spark._ class DoubleRDDSuite extends FunSuite with SharedSparkContext { + test("sum") { + assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) + assert(sc.parallelize(Seq(1.0)).sum() === 1.0) + assert(sc.parallelize(Seq(1.0, 2.0)).sum() === 3.0) + } + // Verify tests on the histogram functionality. We test with both evenly // and non-evenly spaced buckets as the bucket lookup function changes. test("WorksOnEmpty") { @@ -33,6 +39,9 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(0) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) + val emptyRDD: RDD[Double] = sc.emptyRDD + assert(emptyRDD.histogram(buckets) === expectedHistogramResults) + assert(emptyRDD.histogram(buckets, true) === expectedHistogramResults) } test("WorksWithOutOfRangeWithOneBucket") { @@ -232,6 +241,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -245,7 +260,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -256,17 +271,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 6138d0bbd57f..be8467354b22 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -29,22 +29,42 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { Class.forName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { - val create = conn.createStatement - create.execute(""" - CREATE TABLE FOO( - ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), - DATA INTEGER - )""") - create.close() - val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") - (1 to 100).foreach { i => - insert.setInt(1, i * 2) - insert.executeUpdate + + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close() + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close() + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists } - insert.close() - } catch { - case e: SQLException if e.getSQLState == "X0Y32" => + + try { + val create = conn.createStatement + create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)") + create.close() + val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)") + (1 to 100).foreach { i => + insert.setLong(1, 100000000000000000L + 4000000000000000L * i) + insert.setInt(2, i) + insert.executeUpdate + } + insert.close() + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => // table exists + } + } finally { conn.close() } @@ -60,7 +80,19 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) + assert(rdd.reduce(_ + _) === 10100) + } + + test("large id overflow") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?", + 1131544775L, 567279358897692673L, 20, + (r: ResultSet) => { r.getInt(1) } ).cache() + assert(rdd.count === 100) + assert(rdd.reduce(_ + _) === 5050) } after { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 108f70af43f3..ca0d953d306d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,13 +168,13 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() + val sums = pairs.reduceByKey(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } test("reduceByKey with collectAsMap") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() + val sums = pairs.reduceByKey(_ + _).collectAsMap() assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) @@ -182,7 +182,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey with many output partitons") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() + val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -192,7 +192,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def getPartition(key: Any) = key.asInstanceOf[Int] } val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) + val sums = pairs.reduceByKey(_ + _) assert(sums.collect().toSet === Set((1, 4), (0, 1))) assert(sums.partitioner === Some(p)) // count the dependencies to make sure there is only 1 ShuffledRDD @@ -208,7 +208,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("countApproxDistinctByKey") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is * only a statistical bound, the tests can fail for large values of relativeSD. We will be using @@ -465,7 +465,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("foldByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() + val sums = pairs.foldByKey(0)(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -505,7 +505,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { conf.setOutputCommitter(classOf[FakeOutputCommitter]) FakeOutputCommitter.ran = false - pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) + pairs.saveAsHadoopFile( + "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } @@ -552,7 +553,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } private object StratifiedAuxiliary { - def stratifier (fractionPositive: Double) = { + def stratifier (fractionPositive: Double): (Int) => String = { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } @@ -572,7 +573,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, true, samplingRate, seed, n) testPoisson(stratifiedData, true, samplingRate, seed, n) } @@ -580,7 +581,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSample(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, false, samplingRate, seed, n) testPoisson(stratifiedData, false, samplingRate, seed, n) } @@ -590,7 +591,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey() .mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -612,7 +613,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -701,27 +702,27 @@ class FakeOutputFormat() extends OutputFormat[Integer, Integer]() { */ class NewFakeWriter extends NewRecordWriter[Integer, Integer] { - def close(p1: NewTaskAttempContext) = () + def close(p1: NewTaskAttempContext): Unit = () - def write(p1: Integer, p2: Integer) = () + def write(p1: Integer, p2: Integer): Unit = () } class NewFakeCommitter extends NewOutputCommitter { - def setupJob(p1: NewJobContext) = () + def setupJob(p1: NewJobContext): Unit = () def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false - def setupTask(p1: NewTaskAttempContext) = () + def setupTask(p1: NewTaskAttempContext): Unit = () - def commitTask(p1: NewTaskAttempContext) = () + def commitTask(p1: NewTaskAttempContext): Unit = () - def abortTask(p1: NewTaskAttempContext) = () + def abortTask(p1: NewTaskAttempContext): Unit = () } class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { - def checkOutputSpecs(p1: NewJobContext) = () + def checkOutputSpecs(p1: NewJobContext): Unit = () def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { new NewFakeWriter() @@ -735,7 +736,7 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false - def setConf(p1: Configuration) = { + def setConf(p1: Configuration): Unit = { setConfCalled = true () } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index cd193ae4f523..1880364581c1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -100,7 +100,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[Range])) } @@ -108,7 +108,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[Range])) } @@ -139,7 +139,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(i).isInstanceOf[Range]) val range = slices(i).asInstanceOf[Range] assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.end === (i + 1) * (N / 40), "slice " + i + " end") assert(range.step === 1, "slice " + i + " step") } } @@ -156,7 +156,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size /n + 1)) } check(prop) } @@ -174,7 +174,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -192,7 +192,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -209,7 +209,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -217,7 +217,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -225,7 +225,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 8408d7e785c6..465068c6cbb1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.{Partition, SharedSparkContext, TaskContext} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - test("Pruned Partitions inherit locality prefs correctly") { val rdd = new RDD[Int](sc, Nil) { @@ -74,8 +73,6 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { } class TestPartition(i: Int, value: Int) extends Partition with Serializable { - def index = i - - def testValue = this.value - + def index: Int = i + def testValue: Int = this.value } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index a0483886f8db..0d1369c19c69 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -35,7 +35,7 @@ class MockSampler extends RandomSampler[Long, Long] { Iterator(s) } - override def clone = new MockSampler + override def clone: MockSampler = new MockSampler } class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 1a9a0e857e54..aea76c1adcc0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} -import org.apache.spark._ import org.scalatest.FunSuite import scala.collection.Map @@ -30,6 +29,9 @@ import scala.language.postfixOps import scala.sys.process._ import scala.util.Try +import org.apache.spark._ +import org.apache.spark.util.Utils + class PipedRDDSuite extends FunSuite with SharedSparkContext { test("basic pipe") { @@ -141,7 +143,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { // make sure symlinks were created assert(pipedLs.length > 0) // clean up top level tasks directory - new File("tasks").delete() + Utils.deleteRecursively(new File("tasks")) } else { assert(true) } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 381ee2d45630..df42faab6450 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -82,7 +82,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("countApproxDistinct") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble val size = 1000 val uniformDistro = for (i <- 1 to 5000) yield i % size @@ -100,7 +100,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("partitioner aware union") { - def makeRDDWithPartitioner(seq: Seq[Int]) = { + def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) .map(x => (x, null)) .partitionBy(new HashPartitioner(2)) @@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp: (Long, Int) => Long = (c: Long, x: Int) => c + x + def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } + test("basic caching") { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -186,7 +204,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(empty.collect().size === 0) val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) + empty.reduce(_ + _) } assert(thrown.getMessage.contains("empty")) @@ -303,7 +321,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) val coalesced1 = data.coalesce(3) assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") @@ -903,15 +921,17 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("task serialization exception should not hang scheduler") { class BadSerializable extends Serializable { @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization") + private def writeObject(out: ObjectOutputStream): Unit = + throw new KryoException("Bad serialization") @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = {} } - // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if there were - // more threads in the Spark Context than there were number of objects in this sequence. + // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if + // there were more threads in the Spark Context than there were number of objects in this + // sequence. intercept[Throwable] { - sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect + sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect() } // Check that the context has not crashed sc.parallelize(1 to 100).map(x => x*2).collect @@ -927,4 +947,45 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index 4762fc17855c..fe695d85e29d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = a.age compare b.age + def compare(a:Person, b:Person): Int = a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = + def compare(a:Person, b:Person): Int = implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index a40f2ffeffdf..64b1c24c4716 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -119,5 +119,33 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L partitions(1).last should be > partitions(2).head partitions(2).last should be > partitions(3).head } + + test("get a range of elements in a sorted RDD that is on one partition") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey() + val range = sorted.filterByRange(20, 40).collect() + assert((20 to 40).toArray === range.map(_._1)) + } + + test("get a range of elements over multiple partitions in a descendingly sorted RDD") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray === range.map(_._1)) + } + + test("get a range of elements in an array not partitioned by a range partitioner") { + val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairs = sc.parallelize(pairArr,10) + val range = pairs.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) + } + + test("get a range of elements over multiple partitions but not taking up full partitions") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(250, 850).collect() + assert((850 to 250 by -1).toArray === range.map(_._1)) + } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 000000000000..44c88b00c442 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} + +import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{SparkException, SparkConf} + +/** + * Common tests for an RpcEnv implementation. + */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + env = createRpcEnv(conf, "local", 12345) + } + + override def afterAll(): Unit = { + if(env != null) { + env.shutdown() + } + } + + def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint("send-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + + val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithReply[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint("ask-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("ask a message timeout") { + env.setupEndpoint("ask-timeout", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + Thread.sleep(100) + context.reply(msg) + } + } + }) + + val conf = new SparkConf() + conf.set("spark.rpc.retry.wait", "0") + conf.set("spark.rpc.numRetries", "1") + val anotherEnv = createRpcEnv(conf, "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + try { + val e = intercept[Exception] { + rpcEndpointRef.askWithReply[String]("hello", 1 millis) + } + assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint("onError-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint("self-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => { + self + callSelfSuccessfully = true + } + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in onStop") { + @volatile var selfOption: Option[RpcEndpointRef] = null + + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + selfOption = Option(self) + } + + override def onError(cause: Throwable): Unit = { + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for(i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run() { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(5 millis)) { + // Calling stop twice should only trigger onStop once. + assert(onStopCount == 1) + } + } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.sendFailure(new SparkException("Oops")) + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new SparkException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("network events") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress))) + } + } + + test("sendWithReply: unserializable error") { + env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + intercept[TimeoutException] { + Await.result(f, 1 seconds) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + +} + +class UnserializableClass + +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 000000000000..58214c063723 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") + assert("akka.tcp://local@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.shutdown() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index eb116213f69f..3c52a8c4460c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -57,20 +57,18 @@ class MyRDD( locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") - override def getPartitions = (0 until numPartitions).map(i => new Partition { - override def index = i + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i }).toArray override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil + if (locations.isDefinedAt(split.index)) locations(split.index) else Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite + extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -96,6 +94,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} } /** Length of time to wait while draining listener events. */ @@ -208,7 +207,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null)) + runEvent(CompletionEvent( + taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -219,7 +219,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), null, null)) + Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) } } } @@ -268,21 +268,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(new MyRDD(sc, 1, Nil), Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job") { val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" + override def getPartitions: Array[Partition] = + Array( new Partition { override def index: Int = 0 } ) + override def getPreferredLocations(split: Partition): List[String] = Nil + override def toString: String = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job oom") { @@ -294,9 +296,10 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial job w/ dependency") { @@ -305,7 +308,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cache location preferences w/ dependency") { @@ -318,11 +321,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() + } + + test("regression test for getCacheLocs") { + val rdd = new MyRDD(sc, 3, Nil) + cacheLocations(rdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + cacheLocations(rdd.id -> 1) = + Seq(makeBlockManagerId("hostB"), makeBlockManagerId("hostC")) + cacheLocations(rdd.id -> 2) = + Seq(makeBlockManagerId("hostC"), makeBlockManagerId("hostD")) + val locs = scheduler.getCacheLocs(rdd).map(_.map(_.host)) + assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD"))) } test("avoid exponential blowup when getting preferred locs list") { - // Build up a complex dependency graph with repeated zip operations, without preferred locations. + // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. @@ -344,7 +359,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job failure") { @@ -354,7 +369,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job cancellation") { @@ -365,7 +380,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("job cancellation no-kill backend") { @@ -374,18 +389,21 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar val noKillTaskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { + override def start(): Unit = {} + override def stop(): Unit = {} + override def submitTasks(taskSet: TaskSet): Unit = { taskSets += taskSet } override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} + override def defaultParallelism(): Int = 2 + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} } val noKillScheduler = new DAGScheduler( sc, @@ -408,7 +426,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // When the task set completes normally, state should be correctly updated. complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.isEmpty) @@ -428,7 +446,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with fetch failure") { @@ -451,10 +469,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === + Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial shuffle with multiple fetch failures") { @@ -476,7 +495,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(1)) @@ -487,7 +506,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) @@ -507,19 +526,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent( + taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run shuffle with map stage failure") { @@ -538,7 +561,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.toSet === Set(0)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -572,7 +595,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar class FailureRecordingJobListener() extends JobListener { var failureMessage: String = _ override def taskSucceeded(index: Int, result: Any) {} - override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage } + override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage } } val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() @@ -592,7 +615,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with out-of-band failure and retry") { @@ -615,7 +638,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("recursive shuffle failures") { @@ -644,7 +667,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cached post-shuffle") { @@ -676,7 +699,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { @@ -728,15 +751,19 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("accumulator not calculated for resubmitted result stage") { - //just for register + // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) val finalRdd = new MyRDD(sc, 1, Nil) submit(finalRdd, Array(0)) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assert(Accumulators.originals(accum.id).value === 1) - assertDataStructuresEmpty + + val accVal = Accumulators.originals(accum.id).get.get.value + + assert(accVal === 1) + + assertDataStructuresEmpty() } /** @@ -756,7 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - private def assertDataStructuresEmpty = { + private def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -765,6 +792,16 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } + + // Nothing in this test should break if the task info's fields are null, but + // OutputCommitCoordinator requires the task info itself to not be null. + private def createFakeTaskInfo(): TaskInfo = { + val info = new TaskInfo(0, 0, 0, 0L, "", "", TaskLocality.ANY, false) + info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 437d8693c0b1..6d25edb7d20d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -18,15 +18,16 @@ package org.apache.spark.scheduler import java.io.{File, FileOutputStream, InputStream, IOException} +import java.net.URI import scala.collection.mutable import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.{FunSuiteLike, BeforeAndAfter, FunSuite} -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io._ import org.apache.spark.util.{JsonProtocol, Utils} @@ -38,7 +39,8 @@ import org.apache.spark.util.{JsonProtocol, Utils} * logging events, whether the parsing of the file names is correct, and whether the logged events * can be read and deserialized into actual SparkListenerEvents. */ -class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Logging { +class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter + with Logging { import EventLoggingListenerSuite._ private val fileSystem = Utils.getHadoopFileSystem("/", @@ -59,7 +61,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -78,7 +80,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("Basic event logging with compression") { CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => - testEventLogging(compressionCodec = Some(codec)) + testEventLogging(compressionCodec = Some(CompressionCodec.getShortName(codec))) } } @@ -88,25 +90,38 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin test("End-to-end event logging with compression") { CompressionCodec.ALL_COMPRESSION_CODECS.foreach { codec => - testApplicationEventLogging(compressionCodec = Some(codec)) + testApplicationEventLogging(compressionCodec = Some(CompressionCodec.getShortName(codec))) } } test("Log overwriting") { - val log = new FileOutputStream(new File(testDir, "test")) - log.close() - try { - testEventLogging() - assert(false) - } catch { - case e: IOException => - // Expected, since we haven't enabled log overwrite. - } - + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") + val logPath = new URI(logUri).getPath + // Create file before writing the event log + new FileOutputStream(new File(logPath)).close() + // Expected IOException, since we haven't enabled log overwrite. + intercept[IOException] { testEventLogging() } // Try again, but enable overwriting. testEventLogging(extraConf = Map("spark.eventLog.overwrite" -> "true")) } + test("Event log name") { + // without compression + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) + // with compression + assert(s"file:/base-dir/app1.lzf" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) + // illegal characters in app ID + assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) + // illegal characters in app ID with compression + assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) + } + /* ----------------- * * Actual test logic * * ----------------- */ @@ -125,7 +140,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin val conf = getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -133,22 +148,24 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start() + listenerBus.start(sc) listenerBus.addListener(eventLogger) listenerBus.postToAll(applicationStart) listenerBus.postToAll(applicationEnd) eventLogger.stop() // Verify file contains exactly the two events logged - val (logData, version) = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), - fileSystem) + val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) try { val lines = readLines(logData) - assert(lines.size === 2) - assert(lines(0).contains("SparkListenerApplicationStart")) - assert(lines(1).contains("SparkListenerApplicationEnd")) - assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === applicationStart) - assert(JsonProtocol.sparkEventFromJson(parse(lines(1))) === applicationEnd) + val logStart = SparkListenerLogStart(SPARK_VERSION) + assert(lines.size === 3) + assert(lines(0).contains("SparkListenerLogStart")) + assert(lines(1).contains("SparkListenerApplicationStart")) + assert(lines(2).contains("SparkListenerApplicationEnd")) + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + assert(JsonProtocol.sparkEventFromJson(parse(lines(1))) === applicationStart) + assert(JsonProtocol.sparkEventFromJson(parse(lines(2))) === applicationEnd) } finally { logData.close() } @@ -159,12 +176,17 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get - val expectedLogDir = testDir.toURI().toString() - assert(eventLogger.logPath.startsWith(expectedLogDir + "/")) + val eventLogPath = eventLogger.logPath + val expectedLogDir = testDir.toURI() + assert(eventLogPath === EventLoggingListener.getLogPath( + expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) // Begin listening for events that trigger asserts val eventExistenceListener = new EventExistenceListener(eventLogger) @@ -178,8 +200,8 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin eventExistenceListener.assertAllCallbacksInvoked() // Make sure expected events exist in the log file. - val (logData, version) = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), - fileSystem) + val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) + val logStart = SparkListenerLogStart(SPARK_VERSION) val lines = readLines(logData) val eventSet = mutable.Set( SparkListenerApplicationStart, @@ -204,6 +226,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin } } } + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) } @@ -245,7 +268,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin object EventLoggingListenerSuite { /** Get a SparkConf with event logging enabled. */ - def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None) = { + def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") @@ -257,5 +280,5 @@ object EventLoggingListenerSuite { conf } - def getUniqueApplicationId = "test-" + System.currentTimeMillis + def getUniqueApplicationId: String = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 6b75c98839e0..9b92f8de5675 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -24,7 +24,9 @@ import org.apache.spark.TaskContext /** * A Task implementation that fails to serialize. */ -private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) { +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) + extends Task[Array[Byte]](stageId, 0) { + override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala new file mode 100644 index 000000000000..cf9770794670 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.File +import java.util.concurrent.TimeoutException + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} + +import org.apache.spark._ +import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.util.Utils + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +/** + * Unit tests for the output commit coordination functionality. + * + * The unit test makes both the original task and the speculated task + * attempt to commit, where committing is emulated by creating a + * directory. If both tasks create directories then the end result is + * a failure. + * + * Note that there are some aspects of this test that are less than ideal. + * In particular, the test mocks the speculation-dequeuing logic to always + * dequeue a task and consider it as speculated. Immediately after initially + * submitting the tasks and calling reviveOffers(), reviveOffers() is invoked + * again to pick up the speculated task. This may be hacking the original + * behavior in too much of an unrealistic fashion. + * + * Also, the validation is done by checking the number of files in a directory. + * Ideally, an accumulator would be used for this, where we could increment + * the accumulator in the output committer's commitTask() call. If the call to + * commitTask() was called twice erroneously then the test would ideally fail because + * the accumulator would be incremented twice. + * + * The problem with this test implementation is that when both a speculated task and + * its original counterpart complete, only one of the accumulator's increments is + * captured. This results in a paradox where if the OutputCommitCoordinator logic + * was not in SparkHadoopWriter, the tests would still pass because only one of the + * increments would be captured even though the commit in both tasks was executed + * erroneously. + */ +class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { + + var outputCommitCoordinator: OutputCommitCoordinator = null + var tempDir: File = null + var sc: SparkContext = null + + before { + tempDir = Utils.createTempDir() + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName) + .set("spark.speculation", "true") + sc = new SparkContext(conf) { + override private[spark] def createSparkEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + outputCommitCoordinator = spy(new OutputCommitCoordinator(conf)) + // Use Mockito.spy() to maintain the default infrastructure everywhere else. + // This mocking allows us to control the coordinator responses in test cases. + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) + } + } + // Use Mockito.spy() to maintain the default infrastructure everywhere else + val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]) + + doAnswer(new Answer[Unit]() { + override def answer(invoke: InvocationOnMock): Unit = { + // Submit the tasks, then force the task scheduler to dequeue the + // speculated task + invoke.callRealMethod() + mockTaskScheduler.backend.reviveOffers() + } + }).when(mockTaskScheduler).submitTasks(Matchers.any()) + + doAnswer(new Answer[TaskSetManager]() { + override def answer(invoke: InvocationOnMock): TaskSetManager = { + val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet] + new TaskSetManager(mockTaskScheduler, taskSet, 4) { + var hasDequeuedSpeculatedTask = false + override def dequeueSpeculativeTask( + execId: String, + host: String, + locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = { + if (!hasDequeuedSpeculatedTask) { + hasDequeuedSpeculatedTask = true + Some(0, TaskLocality.PROCESS_LOCAL) + } else { + None + } + } + } + } + }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any()) + + sc.taskScheduler = mockTaskScheduler + val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler) + sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler) + sc.dagScheduler = dagSchedulerWithMockTaskScheduler + } + + after { + sc.stop() + tempDir.delete() + outputCommitCoordinator = null + } + + test("Only one of two duplicate commit tasks should commit") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("If commit fails, if task is retried it should not be locked, and will succeed.") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, + 0 until rdd.partitions.size, allowLocal = false) + assert(tempDir.list().size === 1) + } + + test("Job should not complete if all commits are denied") { + // Create a mock OutputCommitCoordinator that denies all attempts to commit + doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( + Matchers.any(), Matchers.any(), Matchers.any()) + val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) + def resultHandler(x: Int, y: Unit): Unit = {} + val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, + OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully, + 0 until rdd.partitions.size, resultHandler, () => Unit) + // It's an error if the job completes successfully even though no committer was authorized, + // so throw an exception if the job was allowed to complete. + intercept[TimeoutException] { + Await.result(futureAction, 5 seconds) + } + assert(tempDir.list().size === 0) + } + + test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { + val stage: Int = 1 + val partition: Long = 2 + val authorizedCommitter: Long = 3 + val nonAuthorizedCommitter: Long = 100 + outputCommitCoordinator.stageStart(stage) + assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + // The non-authorized committer fails + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + // New tasks should still not be able to commit because the authorized committer has not failed + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + // The authorized committer now fails, clearing the lock + outputCommitCoordinator.taskCompleted( + stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + // A new task should now be allowed to become the authorized committer + assert( + outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + // There can only be one authorized committer + assert( + !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + } +} + +/** + * Class with methods that can be passed to runJob to test commits with a mock committer. + */ +private case class OutputCommitFunctions(tempDirPath: String) { + + // Mock output committer that simulates a successful commit (after commit is authorized) + private def successfulOutputCommitter = new FakeOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + Utils.createDirectory(tempDirPath) + } + } + + // Mock output committer that simulates a failed commit (after commit is authorized) + private def failingOutputCommitter = new FakeOutputCommitter { + override def commitTask(taskAttemptContext: TaskAttemptContext) { + throw new RuntimeException + } + } + + def commitSuccessfully(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter) + } + + def failFirstCommitAttempt(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + runCommitWithProvidedCommitter(ctx, iter, + if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) + } + + private def runCommitWithProvidedCommitter( + ctx: TaskContext, + iter: Iterator[Int], + outputCommitter: OutputCommitter): Unit = { + def jobConf = new JobConf { + override def getOutputCommitter(): OutputCommitter = outputCommitter + } + val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { + override def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + mock(classOf[TaskAttemptContext]) + } + } + sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber) + sparkHadoopWriter.commit() + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7e360cc6082e..6de6d2fec622 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -61,7 +62,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, SPARK_VERSION) + replayer.replay(logData, logFilePath.toString) } finally { logData.close() } @@ -115,12 +116,12 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { assert(!eventLog.isDir) // Replay events - val (logData, version) = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) + val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) val eventMonster = new EventMonster(conf) try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, version) + replayer.replay(logData, eventLog.getPath().toString) } finally { logData.close() } @@ -145,16 +146,9 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } } - - private def getCompressionCodec(codecName: String) = { - val conf = new SparkConf - conf.set("spark.io.compression.codec", codecName) - CompressionCodec.createCodec(conf) - } - } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 0fb1bdd30d97..825c616c0c3e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,26 +20,22 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable +import scala.collection.JavaConversions._ -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.scalatest.Matchers +import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter - with BeforeAndAfterAll with ResetSystemProperties { +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers + with ResetSystemProperties { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val jobCompletionTime = 1421191296660L - before { - sc = new SparkContext("local", "SparkListenerSuite") - } - test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -50,7 +46,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start() + bus.start(sc) assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(counter.count === 5) @@ -62,8 +58,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // Listener bus must not be started twice intercept[IllegalStateException] { val bus = new LiveListenerBus - bus.start() - bus.start() + bus.start(sc) + bus.start(sc) } // ... or stopped before starting @@ -89,7 +85,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val stopperReturned = new Semaphore(0) class BlockingListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() drained = true @@ -100,7 +96,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val blockingListener = new BlockingListener bus.addListener(blockingListener) - bus.start() + bus.start(sc) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -127,6 +123,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -148,6 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("basic creation of StageInfo with shuffle") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -185,6 +183,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -201,13 +200,15 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("local metrics") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) // just to make sure some of the tasks take a noticeable amount of time val w = { i: Int => - if (i == 0) + if (i == 0) { Thread.sleep(100) + } i } @@ -247,12 +248,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) + taskMetrics.resultSize should be > (0L) if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) @@ -260,13 +261,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sm.totalBlocksFetched should be (128) sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) + sm.remoteBytesRead should be (0L) } } } } test("onTaskGettingResult() called when result fetched remotely") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -287,6 +289,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskGettingResult() not called when result sent directly") { + sc = new SparkContext("local", "SparkListenerSuite") val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -302,6 +305,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers } test("onTaskEnd() should be called for all started tasks, even after job has been killed") { + sc = new SparkContext("local", "SparkListenerSuite") val WAIT_TIMEOUT_MILLIS = 10000 val listener = new SaveTaskEvents sc.addSparkListener(listener) @@ -344,7 +348,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(badListener) bus.addListener(jobCounter1) bus.addListener(jobCounter2) - bus.start() + bus.start(sc) // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -356,6 +360,17 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(jobCounter2.count === 5) } + test("registering listeners via spark.extraListeners") { + val conf = new SparkConf().setMaster("local").setAppName("test") + .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + + classOf[BasicJobCounter].getName) + sc = new SparkContext(conf) + sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) + sc.listenerBus.listeners.collect { + case x: ListenerThatAcceptsSparkConf => x + }.size should be (1) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -363,14 +378,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers assert(m.sum / m.size.toDouble > 0.0, msg) } - /** - * A simple listener that counts the number of jobs observed. - */ - private class BasicJobCounter extends SparkListener { - var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 - } - /** * A simple listener that saves all task infos and task metrics. */ @@ -400,12 +407,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val startedGettingResultTasks = new mutable.HashSet[Int]() val endedTasks = new mutable.HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { startedTasks += taskStart.taskInfo.index notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { endedTasks += taskEnd.taskInfo.index notify() } @@ -419,7 +426,23 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers * A simple listener that throws an exception on job end. */ private class BadListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { throw new Exception } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } } + +// These classes can't be declared inside of the SparkListenerSuite class because we don't want +// their constructors to contain references to SparkListenerSuite: + +/** + * A simple listener that counts the number of jobs observed. + */ +private class BasicJobCounter extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 +} + +private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { + var count = 0 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index add13f5b2176..ffa4381969b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import java.util.Properties - import org.scalatest.FunSuite import org.apache.spark._ @@ -27,7 +25,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def start() {} def stop() {} def reviveOffers() {} - def defaultParallelism() = 1 + def defaultParallelism(): Int = 1 } class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { @@ -115,7 +113,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } val numFreeCores = 1 taskScheduler.setDAGScheduler(dagScheduler) - var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -123,7 +122,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin assert(0 === taskDescriptions.length) // Now check that we can still submit tasks - // Even if one of the tasks has not-serializable tasks, the other task set should still be processed without error + // Even if one of the tasks has not-serializable tasks, the other task set should + // still be processed without error taskScheduler.submitTasks(taskSet) taskScheduler.submitTasks(FakeTask.createTaskSet(1)) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 84b9b788237b..6198cea46ddf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} import java.util.Random import scala.collection.mutable.ArrayBuffer @@ -27,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.FakeClock +import org.apache.spark.util.{ManualClock, Utils} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -67,7 +66,7 @@ object FakeRackUtil { hostToRack(host) = rack } - def getRackForHost(host: String) = { + def getRackForHost(host: String): Option[String] = { hostToRack.get(host) } } @@ -152,7 +151,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { private val conf = new SparkConf - val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) + val LOCALITY_WAIT_MS = conf.getTimeAsMs("spark.locality.wait", "3s") val MAX_TASK_FAILURES = 4 override def beforeEach() { @@ -164,7 +163,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Offer a host with NO_PREF as the constraint, @@ -213,7 +212,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2")) val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // An executor that is not NODE_LOCAL should be rejected. @@ -234,13 +233,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) @@ -251,7 +250,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) } @@ -263,7 +262,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host2", "exec3")), Seq() // Last task has no locality prefs ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) @@ -283,7 +282,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host3")), Seq(TaskLocation("host2")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen @@ -292,7 +291,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host1 again: nothing should get chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // Offer host1 again: second task (on host2) should get chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -306,7 +305,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Now that we've launched a local task, we should no longer launch the task for host3 assert(manager.resourceOffer("exec2", "host2", ANY) === None) - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // After another delay, we can go ahead and launch that task non-locally assert(manager.resourceOffer("exec2", "host2", ANY).get.index === 3) @@ -314,20 +313,21 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), + ("exec3", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host3")) ) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // After this, nothing should get chosen, because we have separated tasks with unavailable preference - // from the noPrefPendingTasks + // After this, nothing should get chosen, because we have separated tasks with unavailable + // preference from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead @@ -337,7 +337,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) - clock.advance(LOCALITY_WAIT * 2) + clock.advance(LOCALITY_WAIT_MS * 2) // task 1 and 2 would be scheduled as nonLocal task assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) @@ -352,7 +352,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -369,7 +369,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted @@ -401,7 +401,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { ("exec1.1", "host1"), ("exec2", "host2")) // affinity to exec1 on host1 - which we will fail. val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, 4, clock) { @@ -485,7 +485,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) @@ -498,7 +498,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sched.addExecutor("execC", "host2") manager.executorAdded() // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") manager.executorLost("execC", "host2") @@ -521,12 +522,12 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host1", "execA"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) // Set allowed locality to ANY - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // Offer host3 // No task is scheduled if we restrict locality to RACK_LOCAL assert(manager.resourceOffer("execC", "host3", RACK_LOCAL) === None) @@ -568,7 +569,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) intercept[TaskNotSerializableException] { @@ -581,7 +583,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) - def genBytes(size: Int) = { (x: Int) => + def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => val bytes = Array.ofDim[Byte](size) scala.util.Random.nextBytes(bytes) bytes @@ -604,13 +606,14 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2"), TaskLocation("host1")), Seq(), Seq(TaskLocation("host3", "execC"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) @@ -618,25 +621,27 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1) manager.speculatableTasks += 1 - clock.advance(LOCALITY_WAIT) + clock.advance(LOCALITY_WAIT_MS) // schedule the nonPref task assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2) // schedule the speculative task assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1) - clock.advance(LOCALITY_WAIT * 3) + clock.advance(LOCALITY_WAIT_MS * 3) // schedule non-local tasks assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) } - test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + test("node-local tasks should be scheduled right away " + + "when there are only node-local and no-preference tasks") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(), Seq(TaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // node-local tasks are scheduled without delay @@ -649,6 +654,48 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") + { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 2) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 3) + // node-local tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL).get.index === 1) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execB", "host2", NODE_LOCAL) == None) + } + + test("SPARK-4939: no-pref tasks should be scheduled after process-local tasks finished") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + val taskSet = FakeTask.createTaskSet(3, + Seq(), + Seq(ExecutorCacheTaskLocation("host1", "execA")), + Seq(ExecutorCacheTaskLocation("host2", "execB"))) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // process-local tasks are scheduled first + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 1) + assert(manager.resourceOffer("execB", "host2", PROCESS_LOCAL).get.index === 2) + // no-pref tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 0) + assert(manager.resourceOffer("execA", "host1", ANY) == None) + } + test("Ensure TaskSetManager is usable after addition of levels") { // Regression test for SPARK-2931 sc = new SparkContext("local", "test") @@ -656,7 +703,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val taskSet = FakeTask.createTaskSet(2, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2", "execB.1"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(ANY))) @@ -668,13 +715,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) assert(manager.resourceOffer("execA", "host1", ANY) !== None) - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") manager.executorLost("execA", "host1") manager.executorLost("execB.2", "host2") - clock.advance(LOCALITY_WAIT * 4) + clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: @@ -690,7 +737,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq(HostTaskLocation("host1")), Seq(HostTaskLocation("host2")), Seq(HDFSCacheTaskLocation("host3"))) - val clock = new FakeClock + val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) sched.removeExecutor("execA") diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala new file mode 100644 index 000000000000..3fa0115e6825 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.mockito.Mockito._ +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkContext} + +class MemoryUtilsSuite extends FunSuite with MockitoSugar { + test("MesosMemoryUtils should always override memoryOverhead when it's set") { + val sparkConf = new SparkConf + + val sc = mock[SparkContext] + when(sc.conf).thenReturn(sparkConf) + + // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 + when(sc.executorMemory).thenReturn(512) + assert(MemoryUtils.calculateTotalMemory(sc) === 896) + + // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 + when(sc.executorMemory).thenReturn(4096) + assert(MemoryUtils.calculateTotalMemory(sc) === 4505) + + // set memoryOverhead + sparkConf.set("spark.mesos.executor.memoryOverhead", "100") + assert(MemoryUtils.calculateTotalMemory(sc) === 4196) + sparkConf.set("spark.mesos.executor.memoryOverhead", "400") + assert(MemoryUtils.calculateTotalMemory(sc) === 4496) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala similarity index 50% rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 073814c127ed..cdd7be0fbe5d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -15,62 +15,66 @@ * limitations under the License. */ -package org.apache.spark.scheduler.mesos +package org.apache.spark.scheduler.cluster.mesos -import org.apache.spark.executor.MesosExecutorBackend -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, - TaskDescription, WorkerOffer, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} -import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} -import org.apache.mesos.Protos.Value.Scalar -import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer -import java.util.Collections import java.util -import org.scalatest.mock.EasyMockSugar +import java.util.Collections import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { +import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.Protos._ +import org.apache.mesos.SchedulerDriver +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.{ArgumentCaptor, Matchers} +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, + TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) - EasyMock.replay(listenerBus) - - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/spark-home")).anyTimes() - EasyMock.expect(sc.conf).andReturn(conf).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") // uri is null. val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo.getCommand.getValue === + s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo1.getCommand.getValue === + s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } test("mesos resource offers result in launching tasks") { - def createOffer(id: Int, mem: Int, cpu: Int) = { + def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -80,24 +84,25 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() } - val driver = EasyMock.createMock(classOf[SchedulerDriver]) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) - EasyMock.replay(listenerBus) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt val minCpu = 4 @@ -113,33 +118,37 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, mesosOffers.get(0).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(2).getSlaveId.getValue, mesosOffers.get(2).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = new Capture[util.Collection[TaskInfo]] - EasyMock.expect( + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( driver.launchTasks( - EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), - EasyMock.capture(capture), - EasyMock.anyObject(classOf[Filters]) + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) ) - ).andReturn(Status.valueOf(1)).once - EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + ).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(1).getId)).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(2).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers) - EasyMock.verify(driver) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) + verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) assert(capture.getValue.size() == 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) @@ -151,15 +160,13 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea // Unwanted resources offered on an existing node. Make sure they are declined val mesosOffers2 = new java.util.ArrayList[Offer] mesosOffers2.add(createOffer(1, minMem, minCpu)) - EasyMock.reset(taskScheduler) - EasyMock.reset(driver) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]])).andReturn(Seq(Seq()))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) - EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + reset(taskScheduler) + reset(driver) + when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers2) - EasyMock.verify(driver) + verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala similarity index 92% rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala index 86a42a7398e4..eebcba40f8a1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark.scheduler.mesos +package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer import org.scalatest.FunSuite -import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData - class MesosTaskLaunchDataSuite extends FunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 855f1b627608..054a4c64897a 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -29,9 +29,9 @@ class KryoSerializerDistributedSuite extends FunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - conf.set("spark.task.maxFailures", "1") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + .set("spark.task.maxFailures", "1") val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a70f67af2e62..b070a54aa989 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,9 +23,10 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SharedSparkContext} +import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ - +import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -105,7 +106,9 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one",2->"two",3->"three"))) } test("ranges") { @@ -168,7 +171,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("kryo with collect") { val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + val result = sc.parallelize(control, 2) + .map(new ClassWithoutNoArgConstructor(_)) + .collect() + .map(_.x) assert(control === result.toSeq) } @@ -236,12 +242,44 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { // Set a special, broken ClassLoader and make sure we get an exception on deserialization ser.setDefaultClassLoader(new ClassLoader() { - override def loadClass(name: String) = throw new UnsupportedOperationException + override def loadClass(name: String): Class[_] = throw new UnsupportedOperationException }) intercept[UnsupportedOperationException] { ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) } } + + test("registration of HighlyCompressedMapStatus") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 + // values, and they use a bitmap (dense) if they have more than 4096 values, and an + // array (sparse) if they use less. So we just create two cases, one sparse and one dense. + // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly + // empty blocks + + val ser = new KryoSerializer(conf).newInstance() + val denseBlockSizes = new Array[Long](5000) + val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + } + } + + test("serialization buffer overflow reporting") { + import org.apache.spark.SparkException + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + + val largeObject = (1 to 1000000).toArray + + val conf = new SparkConf(false) + conf.set(kryoBufferMaxProperty, "1") + + val ser = new KryoSerializer(conf).newInstance() + val thrown = intercept[SparkException](ser.serialize(largeObject)) + assert(thrown.getMessage.contains(kryoBufferMaxProperty)) + } } @@ -254,14 +292,14 @@ object KryoTest { class ClassWithNoArgConstructor { var x: Int = 0 - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithNoArgConstructor => x == c.x case _ => false } } class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithoutNoArgConstructor => x == c.x case _ => false } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index d037e2c19a64..433fd6bb4a11 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { - def op[T](x: T) = x.toString + def op[T](x: T): String = x.toString - def pred[T](x: T) = x.toString.length % 2 == 0 + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { - def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + def fixture: (RDD[String], UnserializableClass) = { + (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + } test("throws expected serialization exceptions on actions") { val (data, uc) = fixture diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala new file mode 100644 index 000000000000..e62828c4fbac --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ObjectOutput, ObjectInput} + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + + +class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + + import SerializationDebugger.find + + override def beforeEach(): Unit = { + SerializationDebugger.enableDebugging = true + } + + test("primitives, strings, and nulls") { + assert(find(1) === List.empty) + assert(find(1L) === List.empty) + assert(find(1.toShort) === List.empty) + assert(find(1.0) === List.empty) + assert(find("1") === List.empty) + assert(find(null) === List.empty) + } + + test("primitive arrays") { + assert(find(Array[Int](1, 2)) === List.empty) + assert(find(Array[Long](1, 2)) === List.empty) + } + + test("non-primitive arrays") { + assert(find(Array("aa", "bb")) === List.empty) + assert(find(Array(new SerializableClass1)) === List.empty) + } + + test("serializable object") { + assert(find(new Foo(1, "b", 'c', 'd', null, null, null)) === List.empty) + } + + test("nested arrays") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, Array(foo1), null) + assert(find(new Foo(1, "b", 'c', 'd', null, Array(foo2), null)) === List.empty) + } + + test("nested objects") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, null, foo1) + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo2)) === List.empty) + } + + test("cycles (should not loop forever)") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + foo1.g = foo1 + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo1)) === List.empty) + } + + test("root object not serializable") { + val s = find(new NotSerializable) + assert(s.size === 1) + assert(s.head.contains("NotSerializable")) + } + + test("array containing not serializable element") { + val s = find(new SerializableArray(Array(new NotSerializable))) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("element of array")) + assert(s(2).contains("array")) + assert(s(3).contains("arrayField")) + assert(s(4).contains("SerializableArray")) + } + + test("object containing not serializable field") { + val s = find(new SerializableClass2(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + } + + test("externalizable class writing out not serializable object") { + val s = find(new ExternalizableClass) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + assert(s(3).contains("writeExternal")) + assert(s(4).contains("ExternalizableClass")) + } +} + + +class SerializableClass1 extends Serializable + + +class SerializableClass2(val objectField: Object) extends Serializable + + +class SerializableArray(val arrayField: Array[Object]) extends Serializable + + +class ExternalizableClass extends java.io.Externalizable { + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(1) + out.writeObject(new SerializableClass2(new NotSerializable)) + } + + override def readExternal(in: ObjectInput): Unit = {} +} + + +class Foo( + a: Int, + b: String, + c: Char, + d: Byte, + e: Array[Int], + f: Array[Object], + var g: Foo) extends Serializable + + +class NotSerializable diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 0ade1bab18d7..963264cef3a7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag * A serializer implementation that always return a single element in a deserialization stream. */ class TestSerializer extends Serializer { - override def newInstance() = new TestSerializerInstance + override def newInstance(): TestSerializerInstance = new TestSerializerInstance } @@ -36,7 +36,8 @@ class TestSerializerInstance extends SerializerInstance { override def serializeStream(s: OutputStream): SerializationStream = ??? - override def deserializeStream(s: InputStream) = new TestDeserializationStream + override def deserializeStream(s: InputStream): TestDeserializationStream = + new TestDeserializationStream override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 6790388f9660..7d76435cd75e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -54,7 +54,7 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val shuffleBlockManager = - SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockManager] val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) @@ -85,8 +85,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { // Now comes the test : // Write to shuffle 3; and close it, but before registering it, check if the file lengths for // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is concurrent read - // and writes happening to the same shuffle group. + // of block based on remaining data in file : which could mess things up when there is + // concurrent read and writes happening to the same shuffle group. val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c859799..ffa5162a3184 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,32 +68,29 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") // to make a replication attempt to inactive store fail fast - conf.set("spark.core.connection.ack.wait.timeout", "1") + conf.set("spark.core.connection.ack.wait.timeout", "1s") // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ffe6f039145e..7d82a7c66ad1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -66,34 +60,31 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,16 +99,18 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, false, 3) - val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object + // this should return the same object as level1 + val level2 = StorageLevel(false, false, false, false, 3) + // this should return a different object + val level3 = StorageLevel(false, false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -148,6 +141,12 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } + test("BlockManagerId.isDriver() backwards-compatibility with legacy driver ids (SPARK-6716)") { + assert(BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(BlockManagerId(SparkContext.LEGACY_DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver) + } + test("master + 1 manager interaction") { store = makeBlockManager(20000) val a1 = new Array[Byte](4000) @@ -170,8 +169,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -357,10 +356,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -413,8 +410,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach t2.join() t3.join() - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -431,19 +428,19 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) - assert(list1Get.get.inputMetrics.bytesRead === list1SizeEstimate) - assert(list1Get.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list1Get.get.bytes === list1SizeEstimate) + assert(list1Get.get.readMethod === DataReadMethod.Memory) val list2MemoryGet = store.get("list2memory") assert(list2MemoryGet.isDefined, "list2memory expected to be in store") assert(list2MemoryGet.get.data.size === 3) - assert(list2MemoryGet.get.inputMetrics.bytesRead === list2SizeEstimate) - assert(list2MemoryGet.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list2MemoryGet.get.bytes === list2SizeEstimate) + assert(list2MemoryGet.get.readMethod === DataReadMethod.Memory) val list2DiskGet = store.get("list2disk") assert(list2DiskGet.isDefined, "list2memory expected to be in store") assert(list2DiskGet.get.data.size === 3) // We don't know the exact size of the data on disk, but it should certainly be > 0. - assert(list2DiskGet.get.inputMetrics.bytesRead > 0) - assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk) + assert(list2DiskGet.get.bytes > 0) + assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } test("in-memory LRU storage") { @@ -785,7 +782,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -807,7 +804,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Create a non-trivial (not all zeros) byte array var counter = 0.toByte - def incr = {counter = (counter + 1).toByte; counter;} + def incr: Byte = {counter = (counter + 1).toByte; counter;} val bytes = Array.fill[Byte](1000)(incr) val byteBuffer = ByteBuffer.wrap(bytes) @@ -961,8 +958,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size + === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size + === 1) // insert some more blocks store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -970,8 +969,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size + === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size + === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => @@ -1064,6 +1065,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releasePendingUnrollMemoryForThisThread() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) @@ -1074,6 +1076,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() + memoryStore.releasePendingUnrollMemoryForThisThread() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. @@ -1093,8 +1096,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // Unroll with plenty of space. This should succeed and cache both blocks. @@ -1147,8 +1150,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val diskStore = store.diskStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) store.putIterator("b1", smallIterator, memAndDisk) @@ -1190,7 +1193,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memOnly = StorageLevel.MEMORY_ONLY val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // All unroll memory used is released because unrollSafely returned an array @@ -1221,4 +1224,30 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + val result = memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + assert(result.size === 13000) + assert(result.data === null) + assert(result.droppedBlocks === Nil) + } + + test("put a small ByteBuffer to MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + var bytes: ByteBuffer = null + val result = memoryStore.putBytes(blockId, 10000, () => { + bytes = ByteBuffer.allocate(10000) + bytes + }) + assert(result.size === 10000) + assert(result.data === Right(bytes)) + assert(result.droppedBlocks === Nil) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index bbc7e1357b90..003a728cb84a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -16,21 +16,25 @@ */ package org.apache.spark.storage -import org.scalatest.FunSuite import java.io.File + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.SparkConf +import org.apache.spark.util.Utils class BlockObjectWriterSuite extends FunSuite { test("verify write metrics") { - val file = new File("somefile") - file.deleteOnExit() + val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -39,18 +43,20 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.commitAndClose() assert(file.length() == writeMetrics.shuffleBytesWritten) } test("verify write metrics on revert") { - val file = new File("somefile") - file.deleteOnExit() + val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.shuffleBytesWritten == 0) // After 32 writes, metrics should update @@ -59,7 +65,22 @@ class BlockObjectWriterSuite extends FunSuite { writer.write(Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) writer.revertPartialWritesAndClose() assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.shuffleRecordsWritten == 0) + } + + test("Reopening a closed block writer") { + val file = new File(Utils.createTempDir(), "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.open() + writer.close() + intercept[IllegalStateException] { + writer.open() + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336d..b47157f8331c 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.io.File import org.apache.spark.util.Utils -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.SparkConf @@ -28,7 +28,11 @@ import org.apache.spark.SparkConf /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. */ -class LocalDirsSuite extends FunSuite { +class LocalDirsSuite extends FunSuite with BeforeAndAfter { + + before { + Utils.clearLocalRootDirs() + } test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 @@ -43,13 +47,13 @@ class LocalDirsSuite extends FunSuite { assert(!new File("/NONEXISTENT_DIR").exists()) // SPARK_LOCAL_DIRS is a valid directory: class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") else super.getenv(name) } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e85a436cdba1..eb9db550fd74 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,37 +17,53 @@ package org.apache.spark.ui +import java.net.{HttpURLConnection, URL} +import javax.servlet.http.HttpServletRequest + import scala.collection.JavaConversions._ +import scala.xml.Node -import org.openqa.selenium.{By, WebDriver} import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.openqa.selenium.{By, WebDriver} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.shuffle.FetchFailedException + /** - * Selenium tests for the Spark Web UI. These tests are not run by default - * because they're slow. + * Selenium tests for the Spark Web UI. */ -@DoNotDiscover -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { - implicit val webDriver: WebDriver = new HtmlUnitDriver +class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } /** * Create a test SparkContext with the SparkUI enabled. * It is safe to `get` the SparkUI directly from the SparkContext returned here. */ - private def newSparkContext(): SparkContext = { + private def newSparkContext(killEnabled: Boolean = true): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") + .set("spark.ui.port", "0") + .set("spark.ui.killEnabled", killEnabled.toString) val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -93,7 +109,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables find(id("failed")).get.text should be("Failed Stages (1)") } @@ -105,7 +121,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") - find(id("active")).get.text should be("Active Stages (0)") + find(id("active")) should be(None) // Since we hide empty tables // The failure occurs before the stage becomes active, hence we should still show only one // failed stage, not two: find(id("failed")).get.text should be("Failed Stages (1)") @@ -114,21 +130,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } test("spark.ui.killEnabled should properly control kill button display") { - def getSparkContext(killEnabled: Boolean): SparkContext = { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - .set("spark.ui.killEnabled", killEnabled.toString) - new SparkContext(conf) - } - - def hasKillLink = find(className("kill-link")).isDefined + def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } - withSpark(getSparkContext(killEnabled = true)) { sc => + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -136,7 +143,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } - withSpark(getSparkContext(killEnabled = false)) { sc => + withSpark(newSparkContext(killEnabled = false)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -167,13 +174,14 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { test("job progress bars should handle stage / task failures") { withSpark(newSparkContext()) { sc => - val data = sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity) + val data = sc.parallelize(Seq(1, 2, 3), 1).map(identity).groupBy(identity) val shuffleHandle = data.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle // Simulate fetch failures: val mappedData = data.map { x => val taskContext = TaskContext.get - if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt. + if (taskContext.taskAttemptId() == 1) { + // Cause the post-shuffle stage to fail on its first attempt with a single task failure val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId @@ -218,7 +226,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { // because someone could change the error message and cause this test to pass by accident. // Instead, it's safer to check that each row contains a link to a stage details page. findAll(cssSelector("tbody tr")).foreach { row => - val link = row.underlying.findElement(By.xpath(".//a")) + val link = row.underlying.findElement(By.xpath("./td/div/a")) link.getAttribute("href") should include ("stage") } } @@ -299,4 +307,67 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { } } } + + test("attaching and detaching a new tab") { + withSpark(newSparkContext()) { sc => + val sparkUI = sc.ui.get + + val newTab = new WebUITab(sparkUI, "foo") { + attachPage(new WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + "html magic" + } + }) + } + sparkUI.attachTab(newTab) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/")) + find(cssSelector("""ul li a[href*="jobs"]""")) should not be(None) + find(cssSelector("""ul li a[href*="stages"]""")) should not be(None) + find(cssSelector("""ul li a[href*="storage"]""")) should not be(None) + find(cssSelector("""ul li a[href*="environment"]""")) should not be(None) + find(cssSelector("""ul li a[href*="foo"]""")) should not be(None) + } + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check whether new page exists + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/foo") + find(cssSelector("b")).get.text should include ("html magic") + } + sparkUI.detachTab(newTab) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sc.ui.get.appUIAddress.stripSuffix("/")) + find(cssSelector("""ul li a[href*="jobs"]""")) should not be(None) + find(cssSelector("""ul li a[href*="stages"]""")) should not be(None) + find(cssSelector("""ul li a[href*="storage"]""")) should not be(None) + find(cssSelector("""ul li a[href*="environment"]""")) should not be(None) + find(cssSelector("""ul li a[href*="foo"]""")) should be(None) + } + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check new page not exist + go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/foo") + find(cssSelector("b")) should be(None) + } + } + } + + test("kill stage is POST only") { + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + connection.connect() + val code = connection.getResponseCode() + connection.disconnect() + code + } + + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") + getResponseCode(url, "GET") should be (405) + getResponseCode(url, "POST") should be (200) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 92a21f82f3c2..77a038dc1720 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.ui import java.net.ServerSocket -import javax.servlet.http.HttpServletRequest import scala.io.Source import scala.util.{Failure, Success, Try} @@ -28,9 +27,8 @@ import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.LocalSparkContext._ -import scala.xml.Node +import org.apache.spark.{SparkConf, SparkContext} class UISuite extends FunSuite { @@ -72,40 +70,6 @@ class UISuite extends FunSuite { } } - ignore("attaching a new tab") { - withSpark(newSparkContext()) { sc => - val sparkUI = sc.ui.get - - val newTab = new WebUITab(sparkUI, "foo") { - attachPage(new WebUIPage("") { - def render(request: HttpServletRequest): Seq[Node] = { - "html magic" - } - }) - } - sparkUI.attachTab(newTab) - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sparkUI.appUIAddress).mkString - assert(!html.contains("random data that should not be present")) - - // check whether new page exists - assert(html.toLowerCase.contains("foo")) - - // check whether other pages still exist - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) - } - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString - // check whether new page exists - assert(html.contains("magic")) - } - } - } - test("jetty selects different port under contention") { val server = new ServerSocket(0) val startPort = server.getLocalPort diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 68074ae32a67..21d826711413 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.util.Properties + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc SparkListenerStageCompleted(stageInfo) } - private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + private def createJobStartEvent( + jobId: Int, + stageIds: Seq[Int], + jobGroup: Option[String] = None): SparkListenerJobStart = { val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) + val properties: Option[Properties] = jobGroup.map { groupId => + val props = new Properties() + props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + props + } + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -88,6 +98,45 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) } + test("test clearing of stageIdToActiveJobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) + val jobId = 0 + val stageIds = 1 to 50 + // Start a job with 50 stages + listener.onJobStart(createJobStartEvent(jobId, stageIds)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + } + listener.stageIdToActiveJobIds.size should be > 0 + + // Complete the stages and job + for (stageId <- stageIds) { + listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) + } + listener.onJobEnd(createJobEndEvent(jobId, false)) + assertActiveJobsStateIsEmpty(listener) + listener.stageIdToActiveJobIds.size should be (0) + } + + test("test clearing of jobGroupToJobIds") { + val conf = new SparkConf() + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run 50 jobs, each with one stage + for (jobId <- 0 to 50) { + listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) + listener.onStageSubmitted(createStageStartEvent(0)) + listener.onStageCompleted(createStageEndEvent(0, failed = false)) + listener.onJobEnd(createJobEndEvent(jobId, false)) + } + assertActiveJobsStateIsEmpty(listener) + // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs + listener.jobGroupToJobIds.size should be (5) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ -220,13 +269,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) val execId = "exe-1" - def makeTaskMetrics(base: Int) = { + def makeTaskMetrics(base: Int): TaskMetrics = { val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.incRemoteBytesRead(base + 1) + shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) shuffleWriteMetrics.incShuffleBytesWritten(base + 3) taskMetrics.setExecutorRunTime(base + 4) @@ -234,14 +284,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.incMemoryBytesSpilled(base + 6) val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) taskMetrics.setInputMetrics(Some(inputMetrics)) - inputMetrics.addBytesRead(base + 7) + inputMetrics.incBytesRead(base + 7) val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) taskMetrics.outputMetrics = Some(outputMetrics) outputMetrics.setBytesWritten(base + 8) taskMetrics } - def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + def makeTaskInfo(taskId: Long, finishTime: Int = 0): TaskInfo = { val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = finishTime @@ -260,8 +310,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 102) - assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleReadTotalBytes == 220) + assert(stage1Data.shuffleReadTotalBytes == 410) assert(stage0Data.shuffleWriteBytes == 106) assert(stage1Data.shuffleWriteBytes == 203) assert(stage0Data.executorRunTime == 108) @@ -290,8 +340,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc stage0Data = listener.stageIdToData.get((0, 0)).get stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 402) - assert(stage1Data.shuffleReadBytes == 602) + // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed + // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. + assert(stage0Data.shuffleReadTotalBytes == 820) + // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. + assert(stage1Data.shuffleReadTotalBytes == 1220) assert(stage0Data.shuffleWriteBytes == 406) assert(stage1Data.shuffleWriteBytes == 606) assert(stage0Data.executorRunTime == 408) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index e1bc1379b5d8..3744e479d2f0 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -107,7 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + val stageInfo0 = new StageInfo( + 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6bbf72e929dc..bec79fc4dc8f 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.util -import scala.concurrent.Await - -import akka.actor._ +import java.util.concurrent.TimeoutException +import akka.actor.ActorNotFound import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId +import org.apache.spark.SSLSampleConfigs._ /** @@ -35,70 +36,64 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("remote fetch security bad password") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "true") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off") { val conf = new SparkConf conf.set("spark.authenticate", "false") conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -116,41 +111,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security pass") { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf); + val securityManagerGood = new SecurityManager(goodconf) assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = goodconf, securityManager = securityManagerGood) + val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -166,47 +157,200 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off client") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + } + + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + test("remote fetch ssl on") { + val conf = sparkSSLConfig() + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveConf = sparkSSLConfig() + val securityManagerBad = new SecurityManager(slaveConf) + + val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) + + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + + test("remote fetch ssl on and security enabled") { + val conf = sparkSSLConfig() + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveConf = sparkSSLConfig() + slaveConf.set("spark.authenticate", "true") + slaveConf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(slaveConf) + + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) + + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + + test("remote fetch ssl on and security enabled - bad credentials") { + val conf = sparkSSLConfig() + conf.set("spark.rpc", "akka") + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveConf = sparkSSLConfig() + slaveConf.set("spark.rpc", "akka") + slaveConf.set("spark.authenticate", "true") + slaveConf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(slaveConf) + + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + } + + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + + test("remote fetch ssl on - untrusted server") { + val conf = sparkSSLConfigUntrusted() + val securityManager = new SecurityManager(conf) + + val hostname = "localhost" + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveConf = sparkSSLConfig() + val securityManagerBad = new SecurityManager(slaveConf) + + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + try { + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + fail("should receive either ActorNotFound or TimeoutException") + } catch { + case e: ActorNotFound => + case e: TimeoutException => } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 054ef54e746a..c47162779bbb 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -83,7 +83,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -95,7 +95,7 @@ class TestClass extends Serializable { } class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -164,7 +164,7 @@ object TestObjectWithNesting { } class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y + def getY: Int = y def run(): Int = { var nonSer = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 10541f878476..47b535206c94 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() (1 to 100).foreach(eventLoop.post) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert((1 to 100) === buffer.toSeq) } eventLoop.stop() @@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) } eventLoop.stop() @@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) assert(eventLoop.isActive) } @@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts { }.start() } - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(threadNum * eventsFromEachThread === receivedEventsCount) } eventLoop.stop() @@ -185,4 +185,94 @@ class EventLoopSuite extends FunSuite with Timeouts { } assert(false === eventLoop.isActive) } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 4dc5b6103db7..c05317534cdd 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolic class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { - val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile + val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile before { cleanup() @@ -109,7 +109,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // verify whether the earliest file has been deleted val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted - logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + + rolledOverFiles.mkString("\n")) assert(rolledOverFiles.size > 2) val earliestRolledOverFile = rolledOverFiles.head val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( @@ -135,7 +136,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - //assert(appender.getClass === classTag[ExpectedAppender].getClass) + // assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { @@ -153,9 +154,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { import RollingFileAppender._ - def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) - def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) - def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + def rollingStrategy(strategy: String): Seq[(String, String)] = + Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String): Seq[(String, String)] = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String): Seq[(String, String)] = + Seq(INTERVAL_PROPERTY -> interval) val msInDay = 24 * 60 * 60 * 1000L val msInHour = 60 * 60 * 1000L diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0357fc6ce278..a2be724254d7 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -37,6 +37,9 @@ class JsonProtocolSuite extends FunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L + val executorAddedTime = 1421458410000L + val executorRemovedTime = 1421458922000L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -73,9 +76,10 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) - val executorAdded = SparkListenerExecutorAdded("exec1", - new ExecutorInfo("Hostee.awesome.com", 11)) - val executorRemoved = SparkListenerExecutorRemoved("exec2") + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap + val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", + new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) + val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -97,13 +101,14 @@ class JsonProtocolSuite extends FunSuite { } test("Dependent Classes") { + val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) - testExecutorInfo(new ExecutorInfo("host", 43)) + testExecutorInfo(new ExecutorInfo("host", 43, logUrlMap)) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -184,6 +189,34 @@ class JsonProtocolSuite extends FunSuite { assert(newMetrics.inputMetrics.isEmpty) } + test("Input/Output records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = true, hasOutput = true, hasRecords = false) + assert(metrics.inputMetrics.nonEmpty) + assert(metrics.outputMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Records Read" } + .removeField { case (field, _) => field == "Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.inputMetrics.get.recordsRead == 0) + assert(newMetrics.outputMetrics.get.recordsWritten == 0) + } + + test("Shuffle Read/Write records backwards compatibility") { + // records read were added after 1.2 + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + assert(metrics.shuffleWriteMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Total Records Read" } + .removeField { case (field, _) => field == "Shuffle Records Written" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) + assert(newMetrics.shuffleWriteMetrics.get.shuffleRecordsWritten == 0) + } + test("OutputMetrics backward compatibility") { // OutputMetrics were added after 1.1 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = true) @@ -227,6 +260,18 @@ class JsonProtocolSuite extends FunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { + // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } + .removeField { case (field, _) => field == "Local Read Time" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -639,7 +684,8 @@ class JsonProtocolSuite extends FunSuite { e: Int, f: Int, hasHadoopInput: Boolean, - hasOutput: Boolean) = { + hasOutput: Boolean, + hasRecords: Boolean = true) = { val t = new TaskMetrics t.setHostname("localhost") t.setExecutorDeserializeTime(a) @@ -651,7 +697,8 @@ class JsonProtocolSuite extends FunSuite { if (hasHadoopInput) { val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.addBytesRead(d + e + f) + inputMetrics.incBytesRead(d + e + f) + inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) t.setInputMetrics(Some(inputMetrics)) } else { val sr = new ShuffleReadMetrics @@ -659,16 +706,20 @@ class JsonProtocolSuite extends FunSuite { sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) + sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) + sr.incLocalBytesRead(a + f) t.setShuffleReadMetrics(Some(sr)) } if (hasOutput) { val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) outputMetrics.setBytesWritten(a + b + c) + outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) t.outputMetrics = Some(outputMetrics) } else { val sw = new ShuffleWriteMetrics sw.incShuffleBytesWritten(a + b + c) sw.incShuffleWriteTime(b + c + d) + sw.setShuffleRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) t.shuffleWriteMetrics = Some(sw) } // Make at most 6 blocks @@ -902,11 +953,14 @@ class JsonProtocolSuite extends FunSuite { | "Remote Blocks Fetched": 800, | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, - | "Remote Bytes Read": 1000 + | "Remote Bytes Read": 1000, + | "Local Bytes Read": 1100, + | "Total Records Read" : 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Updated Blocks": [ | { @@ -983,11 +1037,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, - | "Shuffle Write Time": 1500 + | "Shuffle Write Time": 1500, + | "Shuffle Records Written": 12 | }, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Updated Blocks": [ | { @@ -1064,11 +1120,13 @@ class JsonProtocolSuite extends FunSuite { | "Disk Bytes Spilled": 0, | "Input Metrics": { | "Data Read Method": "Hadoop", - | "Bytes Read": 2100 + | "Bytes Read": 2100, + | "Records Read": 21 | }, | "Output Metrics": { | "Data Write Method": "Hadoop", - | "Bytes Written": 1200 + | "Bytes Written": 1200, + | "Records Written": 12 | }, | "Updated Blocks": [ | { @@ -1453,22 +1511,29 @@ class JsonProtocolSuite extends FunSuite { """ private val executorAddedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorAdded", + | "Timestamp": ${executorAddedTime}, | "Executor ID": "exec1", | "Executor Info": { | "Host": "Hostee.awesome.com", - | "Total Cores": 11 + | "Total Cores": 11, + | "Log Urls" : { + | "stderr" : "mystderr", + | "stdout" : "mystdout" + | } | } |} """ private val executorRemovedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorRemoved", - | "Executor ID": "exec2" + | "Timestamp": ${executorRemovedTime}, + | "Executor ID": "exec2", + | "Removed Reason": "test reason" |} """ } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala similarity index 73% rename from core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala rename to core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index e2050e95a1b8..87de90bb0dfb 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -15,41 +15,48 @@ * limitations under the License. */ -package org.apache.spark.executor +package org.apache.spark.util import java.net.URLClassLoader import org.scalatest.FunSuite -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils} -import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TestUtils} -class ExecutorURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends FunSuite { - val childClassNames = List("FakeClass1", "FakeClass2") - val parentClassNames = List("FakeClass1", "FakeClass2", "FakeClass3") - val urls = List(TestUtils.createJarWithClasses(childClassNames, "1")).toArray - val urls2 = List(TestUtils.createJarWithClasses(parentClassNames, "2")).toArray + val urls2 = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), + toStringValue = "2")).toArray + val urls = List(TestUtils.createJarWithClasses( + classNames = Seq("FakeClass1"), + classNamesWithBase = Seq(("FakeClass2", "FakeClass3")), // FakeClass3 is in parent + toStringValue = "1", + classpathUrls = urls2)).toArray test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") + val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorURLClassLoader(urls, parentLoader) + val classLoader = new MutableURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") + val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + assert(fakeClass.getClass === fakeClass2.getClass) } test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) val fakeClass = classLoader.loadClass("FakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -57,7 +64,7 @@ class ExecutorURLClassLoaderSuite extends FunSuite { test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("FakeClassDoesNotExist").newInstance() } diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 72e81f3f1a88..403dcb03bd6e 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -71,7 +71,7 @@ class NextIteratorSuite extends FunSuite with Matchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - override def getNext() = { + override def getNext(): Int = { if (ints.size == 0) { finished = true 0 diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index d4b92f33dd9e..bad1aa99952c 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} /** @@ -42,7 +43,11 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - oldProperties = new Properties(System.getProperties) + // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // the later way of creating a copy does not copy the properties but it initializes a new + // Properties object with the given properties as defaults. They are not recognized at all + // by standard Scala wrapper over Java Properties then. + oldProperties = SerializationUtils.clone(System.getProperties) super.beforeEach() } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 7424c2e91d4f..67a9f75ff218 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -98,8 +98,10 @@ class SizeEstimatorSuite // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object - assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object + // 10 pointers plus 8-byte object + assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) + // 100 pointers plus 8-byte object + assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala new file mode 100644 index 000000000000..a3aa3e953fbe --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.scalatest.FunSuite + +class ThreadUtilsSuite extends FunSuite { + + test("newDaemonSingleThreadExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") + @volatile var threadName = "" + executor.submit(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + } + }) + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } + + test("newDaemonSingleThreadScheduledExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("this-is-a-thread-name") + try { + val latch = new CountDownLatch(1) + @volatile var threadName = "" + executor.schedule(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + latch.countDown() + } + }, 1, TimeUnit.MILLISECONDS) + latch.await(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } finally { + executor.shutdownNow() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index c1c605cdb487..8b72fe665c21 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -63,7 +63,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } @@ -93,7 +93,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4544382094f9..1ba99803f5a0 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,22 +17,71 @@ package org.apache.spark.util -import scala.util.Random - import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols +import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.PriorityQueue + +import scala.collection.mutable.ListBuffer +import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { + test("timeConversion") { + // Test -1 + assert(Utils.timeStringAsSeconds("-1") === -1) + + // Test zero + assert(Utils.timeStringAsSeconds("0") === 0) + + assert(Utils.timeStringAsSeconds("1") === 1) + assert(Utils.timeStringAsSeconds("1s") === 1) + assert(Utils.timeStringAsSeconds("1000ms") === 1) + assert(Utils.timeStringAsSeconds("1000000us") === 1) + assert(Utils.timeStringAsSeconds("1m") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1min") === TimeUnit.MINUTES.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1h") === TimeUnit.HOURS.toSeconds(1)) + assert(Utils.timeStringAsSeconds("1d") === TimeUnit.DAYS.toSeconds(1)) + + assert(Utils.timeStringAsMs("1") === 1) + assert(Utils.timeStringAsMs("1ms") === 1) + assert(Utils.timeStringAsMs("1000us") === 1) + assert(Utils.timeStringAsMs("1s") === TimeUnit.SECONDS.toMillis(1)) + assert(Utils.timeStringAsMs("1m") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1min") === TimeUnit.MINUTES.toMillis(1)) + assert(Utils.timeStringAsMs("1h") === TimeUnit.HOURS.toMillis(1)) + assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) + + // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600s") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This breaks 600ds") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("600s This breaks") + } + + intercept[NumberFormatException] { + Utils.timeStringAsMs("This 123s breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -103,7 +152,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val second = 1000 val minute = second * 60 val hour = minute * 60 - def str = Utils.msDurationToString(_) + def str: (Long) => String = Utils.msDurationToString(_) val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() @@ -119,7 +168,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("reading offset bytes of a file") { val tmpDir2 = Utils.createTempDir() - tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8)) @@ -148,7 +196,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("reading offset bytes across multiple files") { val tmpDir = Utils.createTempDir() - tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), UTF_8) Files.write("abcdefghij", files(1), UTF_8) @@ -198,25 +245,26 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() - val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories + // The parent directory has two child directories + val child1: File = Utils.createTempDir(parent.getCanonicalPath) val child2: File = Utils.createTempDir(parent.getCanonicalPath) val child3: File = Utils.createTempDir(child1.getCanonicalPath) // set the last modified time of child1 to 30 secs old child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) // although child1 is old, child2 is still new so return true - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) child2.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) parent.setLastModified(System.currentTimeMillis - (1000 * 30)) // although parent and its immediate children are new, child3 is still old // we expect a full recursive search for new files. - assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) child3.setLastModified(System.currentTimeMillis - (1000 * 30)) - assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) } test("resolveURI") { @@ -336,25 +384,26 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(!tempDir1.exists()) val tempDir2 = Utils.createTempDir() - val tempFile1 = new File(tempDir2, "foo.txt") - Files.touch(tempFile1) - assert(tempFile1.exists()) - Utils.deleteRecursively(tempFile1) - assert(!tempFile1.exists()) + val sourceFile1 = new File(tempDir2, "foo.txt") + Files.touch(sourceFile1) + assert(sourceFile1.exists()) + Utils.deleteRecursively(sourceFile1) + assert(!sourceFile1.exists()) val tempDir3 = new File(tempDir2, "subdir") assert(tempDir3.mkdir()) - val tempFile2 = new File(tempDir3, "bar.txt") - Files.touch(tempFile2) - assert(tempFile2.exists()) + val sourceFile2 = new File(tempDir3, "bar.txt") + Files.touch(sourceFile2) + assert(sourceFile2.exists()) Utils.deleteRecursively(tempDir2) assert(!tempDir2.exists()) assert(!tempDir3.exists()) - assert(!tempFile2.exists()) + assert(!sourceFile2.exists()) } test("loading properties from file") { - val outFile = File.createTempFile("test-load-spark-properties", "test") + val tmpDir = Utils.createTempDir() + val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) try { System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + @@ -367,7 +416,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) } finally { - outFile.delete() + Utils.deleteRecursively(tmpDir) } } @@ -381,4 +430,56 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { require(cnt === 2, "prepare should be called twice") require(time < 500, "preparation time should not count") } + + test("fetch hcfs dir") { + val tempDir = Utils.createTempDir() + val sourceDir = new File(tempDir, "source-dir") + val innerSourceDir = Utils.createTempDir(root=sourceDir.getPath) + val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) + val targetDir = new File(tempDir, "target-dir") + Files.write("some text", sourceFile, UTF_8) + + val path = new Path("file://" + sourceDir.getAbsolutePath) + val conf = new Configuration() + val fs = Utils.getHadoopFileSystem(path.toString, conf) + + assert(!targetDir.isDirectory()) + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + assert(targetDir.isDirectory()) + + // Copy again to make sure it doesn't error if the dir already exists. + Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false) + + val destDir = new File(targetDir, sourceDir.getName()) + assert(destDir.isDirectory()) + + val destInnerDir = new File(destDir, innerSourceDir.getName) + assert(destInnerDir.isDirectory()) + + val destInnerFile = new File(destInnerDir, sourceFile.getName) + assert(destInnerFile.isFile()) + + val filePath = new Path("file://" + sourceFile.getAbsolutePath) + val testFileDir = new File(tempDir, "test-filename") + val testFileName = "testFName" + val testFilefs = Utils.getHadoopFileSystem(filePath.toString, conf) + Utils.fetchHcfsFile(filePath, testFileDir, testFilefs, new SparkConf(), + conf, false, Some(testFileName)) + val newFileName = new File(testFileDir, testFileName) + assert(newFileName.isFile()) + } + + test("shutdown hook manager") { + val manager = new SparkShutdownHookManager() + val output = new ListBuffer[Int]() + + val hook1 = manager.add(1, () => output += 1) + manager.add(3, () => output += 3) + manager.add(2, () => output += 2) + manager.add(4, () => output += 4) + manager.remove(hook1) + + manager.runAll() + assert(output.toList === List(4, 3, 2)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 794a55d61750..ce2968728a99 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite @deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { - def verifyVector(vector: Vector, expectedLength: Int) = { + def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) assert(vector.elements.min > 0.0) assert(vector.elements.max < 1.0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 48f79ea65101..dff8f3ddc816 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -185,7 +185,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // reduceByKey val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) - val result1 = rdd.reduceByKey(_+_).collect() + val result1 = rdd.reduceByKey(_ + _).collect() assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 72d96798b114..9ff067f86af4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -553,10 +553,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = - buffer1 ++= buffer2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) + : ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) @@ -633,14 +633,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: Int) = ArrayBuffer[Int](i) - def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]): ArrayBuffer[Int] = { + buf1 ++= buf2 + } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll( + (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -654,9 +657,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]): ArrayBuffer[String] = + buf1 ++= buf2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) @@ -720,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using wrongOrdering to show integer overflow introduced exception. val rand = new Random(100L) val wrongOrdering = new Ordering[String] { - override def compare(a: String, b: String) = { + override def compare(a: String, b: String): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() h1 - h2 @@ -742,9 +746,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using aggregation and external spill to make sure ExternalSorter using // partitionKeyComparator. - def createCombiner(i: String) = ArrayBuffer(i) - def mergeValue(c: ArrayBuffer[String], i: String) = c += i - def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String): ArrayBuffer[String] = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]): ArrayBuffer[String] = + c1 ++= c2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 6a7087735640..ef890d2ba60f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toString) === i.toString) } } + + test("contains") { + val map = new OpenHashMap[String, Int](2) + map("a") = 1 + assert(map.contains("a")) + assert(!map.contains("b")) + assert(!map.contains(null)) + map(null) = 0 + assert(map.contains(null)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 8c7df7d73dcd..caf378fec8b3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toLong) === i.toString) } } + + test("contains") { + val map = new PrimitiveKeyOpenHashMap[Int, Int](1) + map(0) = 0 + assert(map.contains(0)) + assert(!map.contains(1)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 0cb1ed739765..e0d6cc16bde0 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -65,6 +65,13 @@ class SorterSuite extends FunSuite { } } + // http://www.envisage-project.eu/timsort-specification-and-verification/ + test("SPARK-5984 TimSort bug") { + val data = TestTimSort.getTimSortBugTestSet(67108864) + new Sorter(new IntArraySortDataFormat).sort(data, 0, data.length, Ordering.Int) + (0 to data.length - 2).foreach(i => assert(data(i) <= data(i + 1))) + } + /** Runs an experiment several times. */ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { if (skip) { diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index ef7178bcdf5c..03f5f2d1b852 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -28,7 +28,7 @@ import scala.language.reflectiveCalls class XORShiftRandomSuite extends FunSuite with Matchers { - def fixture = new { + def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala index 4918e2d92beb..daa795a04349 100644 --- a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -44,13 +44,21 @@ class ImplicitSuite { } def testRddToSequenceFileRDDFunctions(): Unit = { - // TODO eliminating `import intToIntWritable` needs refactoring SequenceFileRDDFunctions. - // That will be a breaking change. - import org.apache.spark.SparkContext.intToIntWritable val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.saveAsSequenceFile("/a/test/path") } + def testRddToSequenceFileRDDFunctionsWithWritable(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(org.apache.hadoop.io.IntWritable, org.apache.hadoop.io.Text)] + = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + + def testRddToSequenceFileRDDFunctionsWithBytesArray(): Unit = { + val rdd: org.apache.spark.rdd.RDD[(Int, Array[Byte])] = mockRDD + rdd.saveAsSequenceFile("/a/test/path") + } + def testRddToOrderedRDDFunctions(): Unit = { val rdd: org.apache.spark.rdd.RDD[(Int, Int)] = mockRDD rdd.sortByKey() diff --git a/data/mllib/als/sample_movielens_movies.txt b/data/mllib/als/sample_movielens_movies.txt new file mode 100644 index 000000000000..934a0253849e --- /dev/null +++ b/data/mllib/als/sample_movielens_movies.txt @@ -0,0 +1,100 @@ +0::Movie 0::Romance|Comedy +1::Movie 1::Action|Anime +2::Movie 2::Romance|Thriller +3::Movie 3::Action|Romance +4::Movie 4::Anime|Comedy +5::Movie 5::Action|Action +6::Movie 6::Action|Comedy +7::Movie 7::Anime|Comedy +8::Movie 8::Comedy|Action +9::Movie 9::Anime|Thriller +10::Movie 10::Action|Anime +11::Movie 11::Action|Anime +12::Movie 12::Anime|Comedy +13::Movie 13::Thriller|Action +14::Movie 14::Anime|Comedy +15::Movie 15::Comedy|Thriller +16::Movie 16::Anime|Romance +17::Movie 17::Thriller|Action +18::Movie 18::Action|Comedy +19::Movie 19::Anime|Romance +20::Movie 20::Action|Anime +21::Movie 21::Romance|Thriller +22::Movie 22::Romance|Romance +23::Movie 23::Comedy|Comedy +24::Movie 24::Anime|Action +25::Movie 25::Comedy|Comedy +26::Movie 26::Anime|Romance +27::Movie 27::Anime|Anime +28::Movie 28::Thriller|Anime +29::Movie 29::Anime|Romance +30::Movie 30::Thriller|Romance +31::Movie 31::Thriller|Romance +32::Movie 32::Comedy|Anime +33::Movie 33::Comedy|Comedy +34::Movie 34::Anime|Anime +35::Movie 35::Action|Thriller +36::Movie 36::Anime|Romance +37::Movie 37::Romance|Anime +38::Movie 38::Thriller|Romance +39::Movie 39::Romance|Comedy +40::Movie 40::Action|Anime +41::Movie 41::Comedy|Thriller +42::Movie 42::Comedy|Action +43::Movie 43::Thriller|Anime +44::Movie 44::Anime|Action +45::Movie 45::Comedy|Romance +46::Movie 46::Comedy|Action +47::Movie 47::Romance|Comedy +48::Movie 48::Action|Comedy +49::Movie 49::Romance|Romance +50::Movie 50::Comedy|Romance +51::Movie 51::Action|Action +52::Movie 52::Thriller|Action +53::Movie 53::Action|Action +54::Movie 54::Romance|Thriller +55::Movie 55::Anime|Romance +56::Movie 56::Comedy|Action +57::Movie 57::Action|Anime +58::Movie 58::Thriller|Romance +59::Movie 59::Thriller|Comedy +60::Movie 60::Anime|Comedy +61::Movie 61::Comedy|Action +62::Movie 62::Comedy|Romance +63::Movie 63::Romance|Thriller +64::Movie 64::Romance|Action +65::Movie 65::Anime|Romance +66::Movie 66::Comedy|Action +67::Movie 67::Thriller|Anime +68::Movie 68::Thriller|Romance +69::Movie 69::Action|Comedy +70::Movie 70::Thriller|Thriller +71::Movie 71::Action|Comedy +72::Movie 72::Thriller|Romance +73::Movie 73::Comedy|Action +74::Movie 74::Action|Action +75::Movie 75::Action|Action +76::Movie 76::Comedy|Comedy +77::Movie 77::Comedy|Comedy +78::Movie 78::Comedy|Comedy +79::Movie 79::Thriller|Thriller +80::Movie 80::Comedy|Anime +81::Movie 81::Comedy|Anime +82::Movie 82::Romance|Anime +83::Movie 83::Comedy|Thriller +84::Movie 84::Anime|Action +85::Movie 85::Thriller|Anime +86::Movie 86::Romance|Anime +87::Movie 87::Thriller|Thriller +88::Movie 88::Romance|Thriller +89::Movie 89::Action|Anime +90::Movie 90::Anime|Romance +91::Movie 91::Anime|Thriller +92::Movie 92::Action|Comedy +93::Movie 93::Romance|Thriller +94::Movie 94::Thriller|Comedy +95::Movie 95::Action|Action +96::Movie 96::Thriller|Romance +97::Movie 97::Thriller|Thriller +98::Movie 98::Thriller|Comedy +99::Movie 99::Thriller|Romance diff --git a/data/mllib/als/sample_movielens_ratings.txt b/data/mllib/als/sample_movielens_ratings.txt new file mode 100644 index 000000000000..088914295079 --- /dev/null +++ b/data/mllib/als/sample_movielens_ratings.txt @@ -0,0 +1,1501 @@ +0::2::3::1424380312 +0::3::1::1424380312 +0::5::2::1424380312 +0::9::4::1424380312 +0::11::1::1424380312 +0::12::2::1424380312 +0::15::1::1424380312 +0::17::1::1424380312 +0::19::1::1424380312 +0::21::1::1424380312 +0::23::1::1424380312 +0::26::3::1424380312 +0::27::1::1424380312 +0::28::1::1424380312 +0::29::1::1424380312 +0::30::1::1424380312 +0::31::1::1424380312 +0::34::1::1424380312 +0::37::1::1424380312 +0::41::2::1424380312 +0::44::1::1424380312 +0::45::2::1424380312 +0::46::1::1424380312 +0::47::1::1424380312 +0::48::1::1424380312 +0::50::1::1424380312 +0::51::1::1424380312 +0::54::1::1424380312 +0::55::1::1424380312 +0::59::2::1424380312 +0::61::2::1424380312 +0::64::1::1424380312 +0::67::1::1424380312 +0::68::1::1424380312 +0::69::1::1424380312 +0::71::1::1424380312 +0::72::1::1424380312 +0::77::2::1424380312 +0::79::1::1424380312 +0::83::1::1424380312 +0::87::1::1424380312 +0::89::2::1424380312 +0::91::3::1424380312 +0::92::4::1424380312 +0::94::1::1424380312 +0::95::2::1424380312 +0::96::1::1424380312 +0::98::1::1424380312 +0::99::1::1424380312 +1::2::2::1424380312 +1::3::1::1424380312 +1::4::2::1424380312 +1::6::1::1424380312 +1::9::3::1424380312 +1::12::1::1424380312 +1::13::1::1424380312 +1::14::1::1424380312 +1::16::1::1424380312 +1::19::1::1424380312 +1::21::3::1424380312 +1::27::1::1424380312 +1::28::3::1424380312 +1::33::1::1424380312 +1::36::2::1424380312 +1::37::1::1424380312 +1::40::1::1424380312 +1::41::2::1424380312 +1::43::1::1424380312 +1::44::1::1424380312 +1::47::1::1424380312 +1::50::1::1424380312 +1::54::1::1424380312 +1::56::2::1424380312 +1::57::1::1424380312 +1::58::1::1424380312 +1::60::1::1424380312 +1::62::4::1424380312 +1::63::1::1424380312 +1::67::1::1424380312 +1::68::4::1424380312 +1::70::2::1424380312 +1::72::1::1424380312 +1::73::1::1424380312 +1::74::2::1424380312 +1::76::1::1424380312 +1::77::3::1424380312 +1::78::1::1424380312 +1::81::1::1424380312 +1::82::1::1424380312 +1::85::3::1424380312 +1::86::2::1424380312 +1::88::2::1424380312 +1::91::1::1424380312 +1::92::2::1424380312 +1::93::1::1424380312 +1::94::2::1424380312 +1::96::1::1424380312 +1::97::1::1424380312 +2::4::3::1424380312 +2::6::1::1424380312 +2::8::5::1424380312 +2::9::1::1424380312 +2::10::1::1424380312 +2::12::3::1424380312 +2::13::1::1424380312 +2::15::2::1424380312 +2::18::2::1424380312 +2::19::4::1424380312 +2::22::1::1424380312 +2::26::1::1424380312 +2::28::1::1424380312 +2::34::4::1424380312 +2::35::1::1424380312 +2::37::5::1424380312 +2::38::1::1424380312 +2::39::5::1424380312 +2::40::4::1424380312 +2::47::1::1424380312 +2::50::1::1424380312 +2::52::2::1424380312 +2::54::1::1424380312 +2::55::1::1424380312 +2::57::2::1424380312 +2::58::2::1424380312 +2::59::1::1424380312 +2::61::1::1424380312 +2::62::1::1424380312 +2::64::1::1424380312 +2::65::1::1424380312 +2::66::3::1424380312 +2::68::1::1424380312 +2::71::3::1424380312 +2::76::1::1424380312 +2::77::1::1424380312 +2::78::1::1424380312 +2::80::1::1424380312 +2::83::5::1424380312 +2::85::1::1424380312 +2::87::2::1424380312 +2::88::1::1424380312 +2::89::4::1424380312 +2::90::1::1424380312 +2::92::4::1424380312 +2::93::5::1424380312 +3::0::1::1424380312 +3::1::1::1424380312 +3::2::1::1424380312 +3::7::3::1424380312 +3::8::3::1424380312 +3::9::1::1424380312 +3::14::1::1424380312 +3::15::1::1424380312 +3::16::1::1424380312 +3::18::4::1424380312 +3::19::1::1424380312 +3::24::3::1424380312 +3::26::1::1424380312 +3::29::3::1424380312 +3::33::1::1424380312 +3::34::3::1424380312 +3::35::1::1424380312 +3::36::3::1424380312 +3::37::1::1424380312 +3::38::2::1424380312 +3::43::1::1424380312 +3::44::1::1424380312 +3::46::1::1424380312 +3::47::1::1424380312 +3::51::5::1424380312 +3::52::3::1424380312 +3::56::1::1424380312 +3::58::1::1424380312 +3::60::3::1424380312 +3::62::1::1424380312 +3::65::2::1424380312 +3::66::1::1424380312 +3::67::1::1424380312 +3::68::2::1424380312 +3::70::1::1424380312 +3::72::2::1424380312 +3::76::3::1424380312 +3::79::3::1424380312 +3::80::4::1424380312 +3::81::1::1424380312 +3::83::1::1424380312 +3::84::1::1424380312 +3::86::1::1424380312 +3::87::2::1424380312 +3::88::4::1424380312 +3::89::1::1424380312 +3::91::1::1424380312 +3::94::3::1424380312 +4::1::1::1424380312 +4::6::1::1424380312 +4::8::1::1424380312 +4::9::1::1424380312 +4::10::1::1424380312 +4::11::1::1424380312 +4::12::1::1424380312 +4::13::1::1424380312 +4::14::2::1424380312 +4::15::1::1424380312 +4::17::1::1424380312 +4::20::1::1424380312 +4::22::1::1424380312 +4::23::1::1424380312 +4::24::1::1424380312 +4::29::4::1424380312 +4::30::1::1424380312 +4::31::1::1424380312 +4::34::1::1424380312 +4::35::1::1424380312 +4::36::1::1424380312 +4::39::2::1424380312 +4::40::3::1424380312 +4::41::4::1424380312 +4::43::2::1424380312 +4::44::1::1424380312 +4::45::1::1424380312 +4::46::1::1424380312 +4::47::1::1424380312 +4::49::2::1424380312 +4::50::1::1424380312 +4::51::1::1424380312 +4::52::4::1424380312 +4::54::1::1424380312 +4::55::1::1424380312 +4::60::3::1424380312 +4::61::1::1424380312 +4::62::4::1424380312 +4::63::3::1424380312 +4::65::1::1424380312 +4::67::2::1424380312 +4::69::1::1424380312 +4::70::4::1424380312 +4::71::1::1424380312 +4::73::1::1424380312 +4::78::1::1424380312 +4::84::1::1424380312 +4::85::1::1424380312 +4::87::3::1424380312 +4::88::3::1424380312 +4::89::2::1424380312 +4::96::1::1424380312 +4::97::1::1424380312 +4::98::1::1424380312 +4::99::1::1424380312 +5::0::1::1424380312 +5::1::1::1424380312 +5::4::1::1424380312 +5::5::1::1424380312 +5::8::1::1424380312 +5::9::3::1424380312 +5::10::2::1424380312 +5::13::3::1424380312 +5::15::1::1424380312 +5::19::1::1424380312 +5::20::3::1424380312 +5::21::2::1424380312 +5::23::3::1424380312 +5::27::1::1424380312 +5::28::1::1424380312 +5::29::1::1424380312 +5::31::1::1424380312 +5::36::3::1424380312 +5::38::2::1424380312 +5::39::1::1424380312 +5::42::1::1424380312 +5::48::3::1424380312 +5::49::4::1424380312 +5::50::3::1424380312 +5::51::1::1424380312 +5::52::1::1424380312 +5::54::1::1424380312 +5::55::5::1424380312 +5::56::3::1424380312 +5::58::1::1424380312 +5::60::1::1424380312 +5::61::1::1424380312 +5::64::3::1424380312 +5::65::2::1424380312 +5::68::4::1424380312 +5::70::1::1424380312 +5::71::1::1424380312 +5::72::1::1424380312 +5::74::1::1424380312 +5::79::1::1424380312 +5::81::2::1424380312 +5::84::1::1424380312 +5::85::1::1424380312 +5::86::1::1424380312 +5::88::1::1424380312 +5::90::4::1424380312 +5::91::2::1424380312 +5::95::2::1424380312 +5::99::1::1424380312 +6::0::1::1424380312 +6::1::1::1424380312 +6::2::3::1424380312 +6::5::1::1424380312 +6::6::1::1424380312 +6::9::1::1424380312 +6::10::1::1424380312 +6::15::2::1424380312 +6::16::2::1424380312 +6::17::1::1424380312 +6::18::1::1424380312 +6::20::1::1424380312 +6::21::1::1424380312 +6::22::1::1424380312 +6::24::1::1424380312 +6::25::5::1424380312 +6::26::1::1424380312 +6::28::1::1424380312 +6::30::1::1424380312 +6::33::1::1424380312 +6::38::1::1424380312 +6::39::1::1424380312 +6::43::4::1424380312 +6::44::1::1424380312 +6::45::1::1424380312 +6::48::1::1424380312 +6::49::1::1424380312 +6::50::1::1424380312 +6::53::1::1424380312 +6::54::1::1424380312 +6::55::1::1424380312 +6::56::1::1424380312 +6::58::4::1424380312 +6::59::1::1424380312 +6::60::1::1424380312 +6::61::3::1424380312 +6::63::3::1424380312 +6::66::1::1424380312 +6::67::3::1424380312 +6::68::1::1424380312 +6::69::1::1424380312 +6::71::2::1424380312 +6::73::1::1424380312 +6::75::1::1424380312 +6::77::1::1424380312 +6::79::1::1424380312 +6::81::1::1424380312 +6::84::1::1424380312 +6::85::3::1424380312 +6::86::1::1424380312 +6::87::1::1424380312 +6::88::1::1424380312 +6::89::1::1424380312 +6::91::2::1424380312 +6::94::1::1424380312 +6::95::2::1424380312 +6::96::1::1424380312 +7::1::1::1424380312 +7::2::2::1424380312 +7::3::1::1424380312 +7::4::1::1424380312 +7::7::1::1424380312 +7::10::1::1424380312 +7::11::2::1424380312 +7::14::2::1424380312 +7::15::1::1424380312 +7::16::1::1424380312 +7::18::1::1424380312 +7::21::1::1424380312 +7::22::1::1424380312 +7::23::1::1424380312 +7::25::5::1424380312 +7::26::1::1424380312 +7::29::4::1424380312 +7::30::1::1424380312 +7::31::3::1424380312 +7::32::1::1424380312 +7::33::1::1424380312 +7::35::1::1424380312 +7::37::2::1424380312 +7::39::3::1424380312 +7::40::2::1424380312 +7::42::2::1424380312 +7::44::1::1424380312 +7::45::2::1424380312 +7::47::4::1424380312 +7::48::1::1424380312 +7::49::1::1424380312 +7::53::1::1424380312 +7::54::1::1424380312 +7::55::1::1424380312 +7::56::1::1424380312 +7::59::1::1424380312 +7::61::2::1424380312 +7::62::3::1424380312 +7::63::2::1424380312 +7::66::1::1424380312 +7::67::3::1424380312 +7::74::1::1424380312 +7::75::1::1424380312 +7::76::3::1424380312 +7::77::1::1424380312 +7::81::1::1424380312 +7::82::1::1424380312 +7::84::2::1424380312 +7::85::4::1424380312 +7::86::1::1424380312 +7::92::2::1424380312 +7::96::1::1424380312 +7::97::1::1424380312 +7::98::1::1424380312 +8::0::1::1424380312 +8::2::4::1424380312 +8::3::2::1424380312 +8::4::2::1424380312 +8::5::1::1424380312 +8::7::1::1424380312 +8::9::1::1424380312 +8::11::1::1424380312 +8::15::1::1424380312 +8::18::1::1424380312 +8::19::1::1424380312 +8::21::1::1424380312 +8::29::5::1424380312 +8::31::3::1424380312 +8::33::1::1424380312 +8::35::1::1424380312 +8::36::1::1424380312 +8::40::2::1424380312 +8::44::1::1424380312 +8::45::1::1424380312 +8::50::1::1424380312 +8::51::1::1424380312 +8::52::5::1424380312 +8::53::5::1424380312 +8::54::1::1424380312 +8::55::1::1424380312 +8::56::1::1424380312 +8::58::4::1424380312 +8::60::3::1424380312 +8::62::4::1424380312 +8::64::1::1424380312 +8::67::3::1424380312 +8::69::1::1424380312 +8::71::1::1424380312 +8::72::3::1424380312 +8::77::3::1424380312 +8::78::1::1424380312 +8::79::1::1424380312 +8::83::1::1424380312 +8::85::5::1424380312 +8::86::1::1424380312 +8::88::1::1424380312 +8::90::1::1424380312 +8::92::2::1424380312 +8::95::4::1424380312 +8::96::3::1424380312 +8::97::1::1424380312 +8::98::1::1424380312 +8::99::1::1424380312 +9::2::3::1424380312 +9::3::1::1424380312 +9::4::1::1424380312 +9::5::1::1424380312 +9::6::1::1424380312 +9::7::5::1424380312 +9::9::1::1424380312 +9::12::1::1424380312 +9::14::3::1424380312 +9::15::1::1424380312 +9::19::1::1424380312 +9::21::1::1424380312 +9::22::1::1424380312 +9::24::1::1424380312 +9::25::1::1424380312 +9::26::1::1424380312 +9::30::3::1424380312 +9::32::4::1424380312 +9::35::2::1424380312 +9::36::2::1424380312 +9::37::2::1424380312 +9::38::1::1424380312 +9::39::1::1424380312 +9::43::3::1424380312 +9::49::5::1424380312 +9::50::3::1424380312 +9::53::1::1424380312 +9::54::1::1424380312 +9::58::1::1424380312 +9::59::1::1424380312 +9::60::1::1424380312 +9::61::1::1424380312 +9::63::3::1424380312 +9::64::3::1424380312 +9::68::1::1424380312 +9::69::1::1424380312 +9::70::3::1424380312 +9::71::1::1424380312 +9::73::2::1424380312 +9::75::1::1424380312 +9::77::2::1424380312 +9::81::2::1424380312 +9::82::1::1424380312 +9::83::1::1424380312 +9::84::1::1424380312 +9::86::1::1424380312 +9::87::4::1424380312 +9::88::1::1424380312 +9::90::3::1424380312 +9::94::2::1424380312 +9::95::3::1424380312 +9::97::2::1424380312 +9::98::1::1424380312 +10::0::3::1424380312 +10::2::4::1424380312 +10::4::3::1424380312 +10::7::1::1424380312 +10::8::1::1424380312 +10::10::1::1424380312 +10::13::2::1424380312 +10::14::1::1424380312 +10::16::2::1424380312 +10::17::1::1424380312 +10::18::1::1424380312 +10::21::1::1424380312 +10::22::1::1424380312 +10::24::1::1424380312 +10::25::3::1424380312 +10::28::1::1424380312 +10::35::1::1424380312 +10::36::1::1424380312 +10::37::1::1424380312 +10::38::1::1424380312 +10::39::1::1424380312 +10::40::4::1424380312 +10::41::2::1424380312 +10::42::3::1424380312 +10::43::1::1424380312 +10::49::3::1424380312 +10::50::1::1424380312 +10::51::1::1424380312 +10::52::1::1424380312 +10::55::2::1424380312 +10::56::1::1424380312 +10::58::1::1424380312 +10::63::1::1424380312 +10::66::1::1424380312 +10::67::2::1424380312 +10::68::1::1424380312 +10::75::1::1424380312 +10::77::1::1424380312 +10::79::1::1424380312 +10::86::1::1424380312 +10::89::3::1424380312 +10::90::1::1424380312 +10::97::1::1424380312 +10::98::1::1424380312 +11::0::1::1424380312 +11::6::2::1424380312 +11::9::1::1424380312 +11::10::1::1424380312 +11::11::1::1424380312 +11::12::1::1424380312 +11::13::4::1424380312 +11::16::1::1424380312 +11::18::5::1424380312 +11::19::4::1424380312 +11::20::1::1424380312 +11::21::1::1424380312 +11::22::1::1424380312 +11::23::5::1424380312 +11::25::1::1424380312 +11::27::5::1424380312 +11::30::5::1424380312 +11::32::5::1424380312 +11::35::3::1424380312 +11::36::2::1424380312 +11::37::2::1424380312 +11::38::4::1424380312 +11::39::1::1424380312 +11::40::1::1424380312 +11::41::1::1424380312 +11::43::2::1424380312 +11::45::1::1424380312 +11::47::1::1424380312 +11::48::5::1424380312 +11::50::4::1424380312 +11::51::3::1424380312 +11::59::1::1424380312 +11::61::1::1424380312 +11::62::1::1424380312 +11::64::1::1424380312 +11::66::4::1424380312 +11::67::1::1424380312 +11::69::5::1424380312 +11::70::1::1424380312 +11::71::3::1424380312 +11::72::3::1424380312 +11::75::3::1424380312 +11::76::1::1424380312 +11::77::1::1424380312 +11::78::1::1424380312 +11::79::5::1424380312 +11::80::3::1424380312 +11::81::4::1424380312 +11::82::1::1424380312 +11::86::1::1424380312 +11::88::1::1424380312 +11::89::1::1424380312 +11::90::4::1424380312 +11::94::2::1424380312 +11::97::3::1424380312 +11::99::1::1424380312 +12::2::1::1424380312 +12::4::1::1424380312 +12::6::1::1424380312 +12::7::3::1424380312 +12::8::1::1424380312 +12::14::1::1424380312 +12::15::2::1424380312 +12::16::4::1424380312 +12::17::5::1424380312 +12::18::2::1424380312 +12::21::1::1424380312 +12::22::2::1424380312 +12::23::3::1424380312 +12::24::1::1424380312 +12::25::1::1424380312 +12::27::5::1424380312 +12::30::2::1424380312 +12::31::4::1424380312 +12::35::5::1424380312 +12::38::1::1424380312 +12::41::1::1424380312 +12::44::2::1424380312 +12::45::1::1424380312 +12::50::4::1424380312 +12::51::1::1424380312 +12::52::1::1424380312 +12::53::1::1424380312 +12::54::1::1424380312 +12::56::2::1424380312 +12::57::1::1424380312 +12::60::1::1424380312 +12::63::1::1424380312 +12::64::5::1424380312 +12::66::3::1424380312 +12::67::1::1424380312 +12::70::1::1424380312 +12::72::1::1424380312 +12::74::1::1424380312 +12::75::1::1424380312 +12::77::1::1424380312 +12::78::1::1424380312 +12::79::3::1424380312 +12::82::2::1424380312 +12::83::1::1424380312 +12::84::1::1424380312 +12::85::1::1424380312 +12::86::1::1424380312 +12::87::1::1424380312 +12::88::1::1424380312 +12::91::3::1424380312 +12::92::1::1424380312 +12::94::4::1424380312 +12::95::2::1424380312 +12::96::1::1424380312 +12::98::2::1424380312 +13::0::1::1424380312 +13::3::1::1424380312 +13::4::2::1424380312 +13::5::1::1424380312 +13::6::1::1424380312 +13::12::1::1424380312 +13::14::2::1424380312 +13::15::1::1424380312 +13::17::1::1424380312 +13::18::3::1424380312 +13::20::1::1424380312 +13::21::1::1424380312 +13::22::1::1424380312 +13::26::1::1424380312 +13::27::1::1424380312 +13::29::3::1424380312 +13::31::1::1424380312 +13::33::1::1424380312 +13::40::2::1424380312 +13::43::2::1424380312 +13::44::1::1424380312 +13::45::1::1424380312 +13::49::1::1424380312 +13::51::1::1424380312 +13::52::2::1424380312 +13::53::3::1424380312 +13::54::1::1424380312 +13::62::1::1424380312 +13::63::2::1424380312 +13::64::1::1424380312 +13::68::1::1424380312 +13::71::1::1424380312 +13::72::3::1424380312 +13::73::1::1424380312 +13::74::3::1424380312 +13::77::2::1424380312 +13::78::1::1424380312 +13::79::2::1424380312 +13::83::3::1424380312 +13::85::1::1424380312 +13::86::1::1424380312 +13::87::2::1424380312 +13::88::2::1424380312 +13::90::1::1424380312 +13::93::4::1424380312 +13::94::1::1424380312 +13::98::1::1424380312 +13::99::1::1424380312 +14::1::1::1424380312 +14::3::3::1424380312 +14::4::1::1424380312 +14::5::1::1424380312 +14::6::1::1424380312 +14::7::1::1424380312 +14::9::1::1424380312 +14::10::1::1424380312 +14::11::1::1424380312 +14::12::1::1424380312 +14::13::1::1424380312 +14::14::3::1424380312 +14::15::1::1424380312 +14::16::1::1424380312 +14::17::1::1424380312 +14::20::1::1424380312 +14::21::1::1424380312 +14::24::1::1424380312 +14::25::2::1424380312 +14::27::1::1424380312 +14::28::1::1424380312 +14::29::5::1424380312 +14::31::3::1424380312 +14::34::1::1424380312 +14::36::1::1424380312 +14::37::2::1424380312 +14::39::2::1424380312 +14::40::1::1424380312 +14::44::1::1424380312 +14::45::1::1424380312 +14::47::3::1424380312 +14::48::1::1424380312 +14::49::1::1424380312 +14::51::1::1424380312 +14::52::5::1424380312 +14::53::3::1424380312 +14::54::1::1424380312 +14::55::1::1424380312 +14::56::1::1424380312 +14::62::4::1424380312 +14::63::5::1424380312 +14::67::3::1424380312 +14::68::1::1424380312 +14::69::3::1424380312 +14::71::1::1424380312 +14::72::4::1424380312 +14::73::1::1424380312 +14::76::5::1424380312 +14::79::1::1424380312 +14::82::1::1424380312 +14::83::1::1424380312 +14::88::1::1424380312 +14::93::3::1424380312 +14::94::1::1424380312 +14::95::2::1424380312 +14::96::4::1424380312 +14::98::1::1424380312 +15::0::1::1424380312 +15::1::4::1424380312 +15::2::1::1424380312 +15::5::2::1424380312 +15::6::1::1424380312 +15::7::1::1424380312 +15::13::1::1424380312 +15::14::1::1424380312 +15::15::1::1424380312 +15::17::2::1424380312 +15::19::2::1424380312 +15::22::2::1424380312 +15::23::2::1424380312 +15::25::1::1424380312 +15::26::3::1424380312 +15::27::1::1424380312 +15::28::2::1424380312 +15::29::1::1424380312 +15::32::1::1424380312 +15::33::2::1424380312 +15::34::1::1424380312 +15::35::2::1424380312 +15::36::1::1424380312 +15::37::1::1424380312 +15::39::1::1424380312 +15::42::1::1424380312 +15::46::5::1424380312 +15::48::2::1424380312 +15::50::2::1424380312 +15::51::1::1424380312 +15::52::1::1424380312 +15::58::1::1424380312 +15::62::1::1424380312 +15::64::3::1424380312 +15::65::2::1424380312 +15::72::1::1424380312 +15::73::1::1424380312 +15::74::1::1424380312 +15::79::1::1424380312 +15::80::1::1424380312 +15::81::1::1424380312 +15::82::2::1424380312 +15::85::1::1424380312 +15::87::1::1424380312 +15::91::2::1424380312 +15::96::1::1424380312 +15::97::1::1424380312 +15::98::3::1424380312 +16::2::1::1424380312 +16::5::3::1424380312 +16::6::2::1424380312 +16::7::1::1424380312 +16::9::1::1424380312 +16::12::1::1424380312 +16::14::1::1424380312 +16::15::1::1424380312 +16::19::1::1424380312 +16::21::2::1424380312 +16::29::4::1424380312 +16::30::2::1424380312 +16::32::1::1424380312 +16::34::1::1424380312 +16::36::1::1424380312 +16::38::1::1424380312 +16::46::1::1424380312 +16::47::3::1424380312 +16::48::1::1424380312 +16::49::1::1424380312 +16::50::1::1424380312 +16::51::5::1424380312 +16::54::5::1424380312 +16::55::1::1424380312 +16::56::2::1424380312 +16::57::1::1424380312 +16::60::1::1424380312 +16::63::2::1424380312 +16::65::1::1424380312 +16::67::1::1424380312 +16::72::1::1424380312 +16::74::1::1424380312 +16::80::1::1424380312 +16::81::1::1424380312 +16::82::1::1424380312 +16::85::5::1424380312 +16::86::1::1424380312 +16::90::5::1424380312 +16::91::1::1424380312 +16::93::1::1424380312 +16::94::3::1424380312 +16::95::2::1424380312 +16::96::3::1424380312 +16::98::3::1424380312 +16::99::1::1424380312 +17::2::1::1424380312 +17::3::1::1424380312 +17::6::1::1424380312 +17::10::4::1424380312 +17::11::1::1424380312 +17::13::2::1424380312 +17::17::5::1424380312 +17::19::1::1424380312 +17::20::5::1424380312 +17::22::4::1424380312 +17::28::1::1424380312 +17::29::1::1424380312 +17::33::1::1424380312 +17::34::1::1424380312 +17::35::2::1424380312 +17::37::1::1424380312 +17::38::1::1424380312 +17::45::1::1424380312 +17::46::5::1424380312 +17::47::1::1424380312 +17::49::3::1424380312 +17::51::1::1424380312 +17::55::5::1424380312 +17::56::3::1424380312 +17::57::1::1424380312 +17::58::1::1424380312 +17::59::1::1424380312 +17::60::1::1424380312 +17::63::1::1424380312 +17::66::1::1424380312 +17::68::4::1424380312 +17::69::1::1424380312 +17::70::1::1424380312 +17::72::1::1424380312 +17::73::3::1424380312 +17::78::1::1424380312 +17::79::1::1424380312 +17::82::2::1424380312 +17::84::1::1424380312 +17::90::5::1424380312 +17::91::3::1424380312 +17::92::1::1424380312 +17::93::1::1424380312 +17::94::4::1424380312 +17::95::2::1424380312 +17::97::1::1424380312 +18::1::1::1424380312 +18::4::3::1424380312 +18::5::2::1424380312 +18::6::1::1424380312 +18::7::1::1424380312 +18::10::1::1424380312 +18::11::4::1424380312 +18::12::2::1424380312 +18::13::1::1424380312 +18::15::1::1424380312 +18::18::1::1424380312 +18::20::1::1424380312 +18::21::2::1424380312 +18::22::1::1424380312 +18::23::2::1424380312 +18::25::1::1424380312 +18::26::1::1424380312 +18::27::1::1424380312 +18::28::5::1424380312 +18::29::1::1424380312 +18::31::1::1424380312 +18::32::1::1424380312 +18::36::1::1424380312 +18::38::5::1424380312 +18::39::5::1424380312 +18::40::1::1424380312 +18::42::1::1424380312 +18::43::1::1424380312 +18::44::4::1424380312 +18::46::1::1424380312 +18::47::1::1424380312 +18::48::1::1424380312 +18::51::2::1424380312 +18::55::1::1424380312 +18::56::1::1424380312 +18::57::1::1424380312 +18::62::1::1424380312 +18::63::1::1424380312 +18::66::3::1424380312 +18::67::1::1424380312 +18::70::1::1424380312 +18::75::1::1424380312 +18::76::3::1424380312 +18::77::1::1424380312 +18::80::3::1424380312 +18::81::3::1424380312 +18::82::1::1424380312 +18::83::5::1424380312 +18::84::1::1424380312 +18::97::1::1424380312 +18::98::1::1424380312 +18::99::2::1424380312 +19::0::1::1424380312 +19::1::1::1424380312 +19::2::1::1424380312 +19::4::1::1424380312 +19::6::2::1424380312 +19::11::1::1424380312 +19::12::1::1424380312 +19::14::1::1424380312 +19::23::1::1424380312 +19::26::1::1424380312 +19::31::1::1424380312 +19::32::4::1424380312 +19::33::1::1424380312 +19::34::1::1424380312 +19::37::1::1424380312 +19::38::1::1424380312 +19::41::1::1424380312 +19::43::1::1424380312 +19::45::1::1424380312 +19::48::1::1424380312 +19::49::1::1424380312 +19::50::2::1424380312 +19::53::2::1424380312 +19::54::3::1424380312 +19::55::1::1424380312 +19::56::2::1424380312 +19::58::1::1424380312 +19::61::1::1424380312 +19::62::1::1424380312 +19::63::1::1424380312 +19::64::1::1424380312 +19::65::1::1424380312 +19::69::2::1424380312 +19::72::1::1424380312 +19::74::3::1424380312 +19::76::1::1424380312 +19::78::1::1424380312 +19::79::1::1424380312 +19::81::1::1424380312 +19::82::1::1424380312 +19::84::1::1424380312 +19::86::1::1424380312 +19::87::2::1424380312 +19::90::4::1424380312 +19::93::1::1424380312 +19::94::4::1424380312 +19::95::2::1424380312 +19::96::1::1424380312 +19::98::4::1424380312 +20::0::1::1424380312 +20::1::1::1424380312 +20::2::2::1424380312 +20::4::2::1424380312 +20::6::1::1424380312 +20::8::1::1424380312 +20::12::1::1424380312 +20::21::2::1424380312 +20::22::5::1424380312 +20::24::2::1424380312 +20::25::1::1424380312 +20::26::1::1424380312 +20::29::2::1424380312 +20::30::2::1424380312 +20::32::2::1424380312 +20::39::1::1424380312 +20::40::1::1424380312 +20::41::2::1424380312 +20::45::2::1424380312 +20::48::1::1424380312 +20::50::1::1424380312 +20::51::3::1424380312 +20::53::3::1424380312 +20::55::1::1424380312 +20::57::2::1424380312 +20::60::1::1424380312 +20::61::1::1424380312 +20::64::1::1424380312 +20::66::1::1424380312 +20::70::2::1424380312 +20::72::1::1424380312 +20::73::2::1424380312 +20::75::4::1424380312 +20::76::1::1424380312 +20::77::4::1424380312 +20::78::1::1424380312 +20::79::1::1424380312 +20::84::2::1424380312 +20::85::2::1424380312 +20::88::3::1424380312 +20::89::1::1424380312 +20::90::3::1424380312 +20::91::1::1424380312 +20::92::2::1424380312 +20::93::1::1424380312 +20::94::4::1424380312 +20::97::1::1424380312 +21::0::1::1424380312 +21::2::4::1424380312 +21::3::1::1424380312 +21::7::2::1424380312 +21::11::1::1424380312 +21::12::1::1424380312 +21::13::1::1424380312 +21::14::3::1424380312 +21::17::1::1424380312 +21::19::1::1424380312 +21::20::1::1424380312 +21::21::1::1424380312 +21::22::1::1424380312 +21::23::1::1424380312 +21::24::1::1424380312 +21::27::1::1424380312 +21::29::5::1424380312 +21::30::2::1424380312 +21::38::1::1424380312 +21::40::2::1424380312 +21::43::3::1424380312 +21::44::1::1424380312 +21::45::1::1424380312 +21::46::1::1424380312 +21::48::1::1424380312 +21::51::1::1424380312 +21::53::5::1424380312 +21::54::1::1424380312 +21::55::1::1424380312 +21::56::1::1424380312 +21::58::3::1424380312 +21::59::3::1424380312 +21::64::1::1424380312 +21::66::1::1424380312 +21::68::1::1424380312 +21::71::1::1424380312 +21::73::1::1424380312 +21::74::4::1424380312 +21::80::1::1424380312 +21::81::1::1424380312 +21::83::1::1424380312 +21::84::1::1424380312 +21::85::3::1424380312 +21::87::4::1424380312 +21::89::2::1424380312 +21::92::2::1424380312 +21::96::3::1424380312 +21::99::1::1424380312 +22::0::1::1424380312 +22::3::2::1424380312 +22::5::2::1424380312 +22::6::2::1424380312 +22::9::1::1424380312 +22::10::1::1424380312 +22::11::1::1424380312 +22::13::1::1424380312 +22::14::1::1424380312 +22::16::1::1424380312 +22::18::3::1424380312 +22::19::1::1424380312 +22::22::5::1424380312 +22::25::1::1424380312 +22::26::1::1424380312 +22::29::3::1424380312 +22::30::5::1424380312 +22::32::4::1424380312 +22::33::1::1424380312 +22::35::1::1424380312 +22::36::3::1424380312 +22::37::1::1424380312 +22::40::1::1424380312 +22::41::3::1424380312 +22::44::1::1424380312 +22::45::2::1424380312 +22::48::1::1424380312 +22::51::5::1424380312 +22::55::1::1424380312 +22::56::2::1424380312 +22::60::3::1424380312 +22::61::1::1424380312 +22::62::4::1424380312 +22::63::1::1424380312 +22::65::1::1424380312 +22::66::1::1424380312 +22::68::4::1424380312 +22::69::4::1424380312 +22::70::3::1424380312 +22::71::1::1424380312 +22::74::5::1424380312 +22::75::5::1424380312 +22::78::1::1424380312 +22::80::3::1424380312 +22::81::1::1424380312 +22::82::1::1424380312 +22::84::1::1424380312 +22::86::1::1424380312 +22::87::3::1424380312 +22::88::5::1424380312 +22::90::2::1424380312 +22::92::3::1424380312 +22::95::2::1424380312 +22::96::2::1424380312 +22::98::4::1424380312 +22::99::1::1424380312 +23::0::1::1424380312 +23::2::1::1424380312 +23::4::1::1424380312 +23::6::2::1424380312 +23::10::4::1424380312 +23::12::1::1424380312 +23::13::4::1424380312 +23::14::1::1424380312 +23::15::1::1424380312 +23::18::4::1424380312 +23::22::2::1424380312 +23::23::4::1424380312 +23::24::1::1424380312 +23::25::1::1424380312 +23::26::1::1424380312 +23::27::5::1424380312 +23::28::1::1424380312 +23::29::1::1424380312 +23::30::4::1424380312 +23::32::5::1424380312 +23::33::2::1424380312 +23::36::3::1424380312 +23::37::1::1424380312 +23::38::1::1424380312 +23::39::1::1424380312 +23::43::1::1424380312 +23::48::5::1424380312 +23::49::5::1424380312 +23::50::4::1424380312 +23::53::1::1424380312 +23::55::5::1424380312 +23::57::1::1424380312 +23::59::1::1424380312 +23::60::1::1424380312 +23::61::1::1424380312 +23::64::4::1424380312 +23::65::5::1424380312 +23::66::2::1424380312 +23::67::1::1424380312 +23::68::3::1424380312 +23::69::1::1424380312 +23::72::1::1424380312 +23::73::3::1424380312 +23::77::1::1424380312 +23::82::2::1424380312 +23::83::1::1424380312 +23::84::1::1424380312 +23::85::1::1424380312 +23::87::3::1424380312 +23::88::1::1424380312 +23::95::2::1424380312 +23::97::1::1424380312 +24::4::1::1424380312 +24::6::3::1424380312 +24::7::1::1424380312 +24::10::2::1424380312 +24::12::1::1424380312 +24::15::1::1424380312 +24::19::1::1424380312 +24::24::1::1424380312 +24::27::3::1424380312 +24::30::5::1424380312 +24::31::1::1424380312 +24::32::3::1424380312 +24::33::1::1424380312 +24::37::1::1424380312 +24::39::1::1424380312 +24::40::1::1424380312 +24::42::1::1424380312 +24::43::3::1424380312 +24::45::2::1424380312 +24::46::1::1424380312 +24::47::1::1424380312 +24::48::1::1424380312 +24::49::1::1424380312 +24::50::1::1424380312 +24::52::5::1424380312 +24::57::1::1424380312 +24::59::4::1424380312 +24::63::4::1424380312 +24::65::1::1424380312 +24::66::1::1424380312 +24::67::1::1424380312 +24::68::3::1424380312 +24::69::5::1424380312 +24::71::1::1424380312 +24::72::4::1424380312 +24::77::4::1424380312 +24::78::1::1424380312 +24::80::1::1424380312 +24::82::1::1424380312 +24::84::1::1424380312 +24::86::1::1424380312 +24::87::1::1424380312 +24::88::2::1424380312 +24::89::1::1424380312 +24::90::5::1424380312 +24::91::1::1424380312 +24::92::1::1424380312 +24::94::2::1424380312 +24::95::1::1424380312 +24::96::5::1424380312 +24::98::1::1424380312 +24::99::1::1424380312 +25::1::3::1424380312 +25::2::1::1424380312 +25::7::1::1424380312 +25::9::1::1424380312 +25::12::3::1424380312 +25::16::3::1424380312 +25::17::1::1424380312 +25::18::1::1424380312 +25::20::1::1424380312 +25::22::1::1424380312 +25::23::1::1424380312 +25::26::2::1424380312 +25::29::1::1424380312 +25::30::1::1424380312 +25::31::2::1424380312 +25::33::4::1424380312 +25::34::3::1424380312 +25::35::2::1424380312 +25::36::1::1424380312 +25::37::1::1424380312 +25::40::1::1424380312 +25::41::1::1424380312 +25::43::1::1424380312 +25::47::4::1424380312 +25::50::1::1424380312 +25::51::1::1424380312 +25::53::1::1424380312 +25::56::1::1424380312 +25::58::2::1424380312 +25::64::2::1424380312 +25::67::2::1424380312 +25::68::1::1424380312 +25::70::1::1424380312 +25::71::4::1424380312 +25::73::1::1424380312 +25::74::1::1424380312 +25::76::1::1424380312 +25::79::1::1424380312 +25::82::1::1424380312 +25::84::2::1424380312 +25::85::1::1424380312 +25::91::3::1424380312 +25::92::1::1424380312 +25::94::1::1424380312 +25::95::1::1424380312 +25::97::2::1424380312 +26::0::1::1424380312 +26::1::1::1424380312 +26::2::1::1424380312 +26::3::1::1424380312 +26::4::4::1424380312 +26::5::2::1424380312 +26::6::3::1424380312 +26::7::5::1424380312 +26::13::3::1424380312 +26::14::1::1424380312 +26::16::1::1424380312 +26::18::3::1424380312 +26::20::1::1424380312 +26::21::3::1424380312 +26::22::5::1424380312 +26::23::5::1424380312 +26::24::5::1424380312 +26::27::1::1424380312 +26::31::1::1424380312 +26::35::1::1424380312 +26::36::4::1424380312 +26::40::1::1424380312 +26::44::1::1424380312 +26::45::2::1424380312 +26::47::1::1424380312 +26::48::1::1424380312 +26::49::3::1424380312 +26::50::2::1424380312 +26::52::1::1424380312 +26::54::4::1424380312 +26::55::1::1424380312 +26::57::3::1424380312 +26::58::1::1424380312 +26::61::1::1424380312 +26::62::2::1424380312 +26::66::1::1424380312 +26::68::4::1424380312 +26::71::1::1424380312 +26::73::4::1424380312 +26::76::1::1424380312 +26::81::3::1424380312 +26::85::1::1424380312 +26::86::3::1424380312 +26::88::5::1424380312 +26::91::1::1424380312 +26::94::5::1424380312 +26::95::1::1424380312 +26::96::1::1424380312 +26::97::1::1424380312 +27::0::1::1424380312 +27::9::1::1424380312 +27::10::1::1424380312 +27::18::4::1424380312 +27::19::3::1424380312 +27::20::1::1424380312 +27::22::2::1424380312 +27::24::2::1424380312 +27::25::1::1424380312 +27::27::3::1424380312 +27::28::1::1424380312 +27::29::1::1424380312 +27::31::1::1424380312 +27::33::3::1424380312 +27::40::1::1424380312 +27::42::1::1424380312 +27::43::1::1424380312 +27::44::3::1424380312 +27::45::1::1424380312 +27::51::3::1424380312 +27::52::1::1424380312 +27::55::3::1424380312 +27::57::1::1424380312 +27::59::1::1424380312 +27::60::1::1424380312 +27::61::1::1424380312 +27::64::1::1424380312 +27::66::3::1424380312 +27::68::1::1424380312 +27::70::1::1424380312 +27::71::2::1424380312 +27::72::1::1424380312 +27::75::3::1424380312 +27::78::1::1424380312 +27::80::3::1424380312 +27::82::1::1424380312 +27::83::3::1424380312 +27::86::1::1424380312 +27::87::2::1424380312 +27::90::1::1424380312 +27::91::1::1424380312 +27::92::1::1424380312 +27::93::1::1424380312 +27::94::2::1424380312 +27::95::1::1424380312 +27::98::1::1424380312 +28::0::3::1424380312 +28::1::1::1424380312 +28::2::4::1424380312 +28::3::1::1424380312 +28::6::1::1424380312 +28::7::1::1424380312 +28::12::5::1424380312 +28::13::2::1424380312 +28::14::1::1424380312 +28::15::1::1424380312 +28::17::1::1424380312 +28::19::3::1424380312 +28::20::1::1424380312 +28::23::3::1424380312 +28::24::3::1424380312 +28::27::1::1424380312 +28::29::1::1424380312 +28::33::1::1424380312 +28::34::1::1424380312 +28::36::1::1424380312 +28::38::2::1424380312 +28::39::2::1424380312 +28::44::1::1424380312 +28::45::1::1424380312 +28::49::4::1424380312 +28::50::1::1424380312 +28::52::1::1424380312 +28::54::1::1424380312 +28::56::1::1424380312 +28::57::3::1424380312 +28::58::1::1424380312 +28::59::1::1424380312 +28::60::1::1424380312 +28::62::3::1424380312 +28::63::1::1424380312 +28::65::1::1424380312 +28::75::1::1424380312 +28::78::1::1424380312 +28::81::5::1424380312 +28::82::4::1424380312 +28::83::1::1424380312 +28::85::1::1424380312 +28::88::2::1424380312 +28::89::4::1424380312 +28::90::1::1424380312 +28::92::5::1424380312 +28::94::1::1424380312 +28::95::2::1424380312 +28::98::1::1424380312 +28::99::1::1424380312 +29::3::1::1424380312 +29::4::1::1424380312 +29::5::1::1424380312 +29::7::2::1424380312 +29::9::1::1424380312 +29::10::3::1424380312 +29::11::1::1424380312 +29::13::3::1424380312 +29::14::1::1424380312 +29::15::1::1424380312 +29::17::3::1424380312 +29::19::3::1424380312 +29::22::3::1424380312 +29::23::4::1424380312 +29::25::1::1424380312 +29::29::1::1424380312 +29::31::1::1424380312 +29::32::4::1424380312 +29::33::2::1424380312 +29::36::2::1424380312 +29::38::3::1424380312 +29::39::1::1424380312 +29::42::1::1424380312 +29::46::5::1424380312 +29::49::3::1424380312 +29::51::2::1424380312 +29::59::1::1424380312 +29::61::1::1424380312 +29::62::1::1424380312 +29::67::1::1424380312 +29::68::3::1424380312 +29::69::1::1424380312 +29::70::1::1424380312 +29::74::1::1424380312 +29::75::1::1424380312 +29::79::2::1424380312 +29::80::1::1424380312 +29::81::2::1424380312 +29::83::1::1424380312 +29::85::1::1424380312 +29::86::1::1424380312 +29::90::4::1424380312 +29::93::1::1424380312 +29::94::4::1424380312 +29::97::1::1424380312 +29::99::1::1424380312 diff --git a/data/mllib/gmm_data.txt b/data/mllib/gmm_data.txt new file mode 100644 index 000000000000..934ee4a83a2d --- /dev/null +++ b/data/mllib/gmm_data.txt @@ -0,0 +1,2000 @@ + 2.59470454e+00 2.12298217e+00 + 1.15807024e+00 -1.46498723e-01 + 2.46206638e+00 6.19556894e-01 + -5.54845070e-01 -7.24700066e-01 + -3.23111426e+00 -1.42579084e+00 + 3.02978115e+00 7.87121753e-01 + 1.97365907e+00 1.15914704e+00 + -6.44852101e+00 -3.18154314e+00 + 1.30963349e+00 1.62866434e-01 + 4.26482541e+00 2.15547996e+00 + 3.79927257e+00 1.50572445e+00 + 4.17452609e-01 -6.74032760e-01 + 4.21117627e-01 4.45590255e-01 + -2.80425571e+00 -7.77150554e-01 + 2.55928797e+00 7.03954218e-01 + 1.32554059e+00 -9.46663152e-01 + -3.39691439e+00 -1.49005743e+00 + -2.26542270e-01 3.60052515e-02 + 1.04994198e+00 5.29825685e-01 + -1.51566882e+00 -1.86264432e-01 + -3.27928172e-01 -7.60859110e-01 + -3.18054866e-01 3.97719805e-01 + 1.65579418e-01 -3.47232033e-01 + 6.47162333e-01 4.96059961e-02 + -2.80776647e-01 4.79418757e-01 + 7.45069752e-01 1.20790281e-01 + 2.13604102e-01 1.59542555e-01 + -3.08860224e+00 -1.43259870e+00 + 8.97066497e-01 1.10206801e+00 + -2.23918874e-01 -1.07267267e+00 + 2.51525708e+00 2.84761973e-01 + 9.98052532e-01 1.08333783e+00 + 1.76705588e+00 8.18866778e-01 + 5.31555163e-02 -1.90111151e-01 + -2.17405059e+00 7.21854582e-02 + -2.13772505e+00 -3.62010387e-01 + 2.95974057e+00 1.31602381e+00 + 2.74053561e+00 1.61781757e+00 + 6.68135448e-01 2.86586009e-01 + 2.82323739e+00 1.74437257e+00 + 8.11540288e-01 5.50744478e-01 + 4.10050897e-01 5.10668402e-03 + 9.58626136e-01 -3.49633680e-01 + 4.66599798e+00 1.49964894e+00 + 4.94507794e-01 2.58928077e-01 + -2.36029742e+00 -1.61042909e+00 + -4.99306804e-01 -8.04984769e-01 + 1.07448510e+00 9.39605828e-01 + -1.80448949e+00 -1.05983264e+00 + -3.22353821e-01 1.73612093e-01 + 1.85418702e+00 1.15640643e+00 + 6.93794163e-01 6.59993560e-01 + 1.99399102e+00 1.44547123e+00 + 3.38866124e+00 1.23379290e+00 + -4.24067720e+00 -1.22264282e+00 + 6.03230201e-02 2.95232729e-01 + -3.59341813e+00 -7.17453726e-01 + 4.87447372e-01 -2.00733911e-01 + 1.20149195e+00 4.07880197e-01 + -2.13331464e+00 -4.58518077e-01 + -3.84091083e+00 -1.71553950e+00 + -5.37279250e-01 2.64822629e-02 + -2.10155227e+00 -1.32558103e+00 + -1.71318897e+00 -7.12098563e-01 + -1.46280695e+00 -1.84868337e-01 + -3.59785325e+00 -1.54832434e+00 + -5.77528081e-01 -5.78580857e-01 + 3.14734283e-01 5.80184639e-01 + -2.71164714e+00 -1.19379432e+00 + 1.09634489e+00 7.20143887e-01 + -3.05527722e+00 -1.47774064e+00 + 6.71753586e-01 7.61350020e-01 + 3.98294144e+00 1.54166484e+00 + -3.37220384e+00 -2.21332064e+00 + 1.81222914e+00 7.41212752e-01 + 2.71458282e-01 1.36329078e-01 + -3.97815359e-01 1.16766886e-01 + -1.70192814e+00 -9.75851571e-01 + -3.46803804e+00 -1.09965988e+00 + -1.69649627e+00 -5.76045801e-01 + -1.02485636e-01 -8.81841246e-01 + -3.24194667e-02 2.55429276e-01 + -2.75343168e+00 -1.51366320e+00 + -2.78676702e+00 -5.22360489e-01 + 1.70483164e+00 1.19769805e+00 + 4.92022579e-01 3.24944706e-01 + 2.48768464e+00 1.00055363e+00 + 4.48786400e-01 7.63902870e-01 + 2.93862696e+00 1.73809968e+00 + -3.55019305e+00 -1.97875558e+00 + 1.74270784e+00 6.90229224e-01 + 5.13391994e-01 4.58374016e-01 + 1.78379499e+00 9.08026381e-01 + 1.75814147e+00 7.41449784e-01 + -2.30687792e-01 3.91009729e-01 + 3.92271353e+00 1.44006290e+00 + 2.93361679e-01 -4.99886375e-03 + 2.47902690e-01 -7.49542503e-01 + -3.97675355e-01 1.36824887e-01 + 3.56535953e+00 1.15181329e+00 + 3.22425301e+00 1.28702383e+00 + -2.94192478e-01 -2.42382557e-01 + 8.02068864e-01 -1.51671475e-01 + 8.54133530e-01 -4.89514885e-02 + -1.64316316e-01 -5.34642346e-01 + -6.08485405e-01 -2.10332352e-01 + -2.18940059e+00 -1.07024952e+00 + -1.71586960e+00 -2.83333492e-02 + 1.70200448e-01 -3.28031178e-01 + -1.97210346e+00 -5.39948532e-01 + 2.19500160e+00 1.05697170e+00 + -1.76239935e+00 -1.09377438e+00 + 1.68314744e+00 6.86491164e-01 + -2.99852288e+00 -1.46619067e+00 + -2.23769560e+00 -9.15008355e-01 + 9.46887516e-01 5.58410503e-01 + 5.02153123e-01 1.63851235e-01 + -9.70297062e-01 3.14625374e-01 + -1.29405593e+00 -8.20994131e-01 + 2.72516079e+00 7.85839947e-01 + 1.45788024e+00 3.37487353e-01 + -4.36292749e-01 -5.42150480e-01 + 2.21304711e+00 1.25254042e+00 + -1.20810271e-01 4.79632898e-01 + -3.30884511e+00 -1.50607586e+00 + -6.55882455e+00 -1.94231256e+00 + -3.17033630e+00 -9.94678930e-01 + 1.42043617e+00 7.28808957e-01 + -1.57546099e+00 -1.10320497e+00 + -3.22748754e+00 -1.64174579e+00 + 2.96776017e-03 -3.16191512e-02 + -2.25986054e+00 -6.13123197e-01 + 2.49434243e+00 7.73069183e-01 + 9.08494049e-01 -1.53926853e-01 + -2.80559090e+00 -1.37474221e+00 + 4.75224286e-01 2.53153674e-01 + 4.37644006e+00 8.49116998e-01 + 2.27282959e+00 6.16568202e-01 + 1.16006880e+00 1.65832798e-01 + -1.67163193e+00 -1.22555386e+00 + -1.38231118e+00 -7.29575504e-01 + -3.49922750e+00 -2.26446675e+00 + -3.73780110e-01 -1.90657869e-01 + 1.68627679e+00 1.05662987e+00 + -3.28891792e+00 -1.11080334e+00 + -2.59815798e+00 -1.51410198e+00 + -2.61203309e+00 -6.00143552e-01 + 6.58964943e-01 4.47216094e-01 + -2.26711381e+00 -7.26512923e-01 + -5.31429009e-02 -1.97925341e-02 + 3.19749807e+00 9.20425476e-01 + -1.37595787e+00 -6.58062732e-01 + 8.09900278e-01 -3.84286160e-01 + -5.07741280e+00 -1.97683808e+00 + -2.99764250e+00 -1.50753777e+00 + -9.87671815e-01 -4.63255889e-01 + 1.65390765e+00 6.73806615e-02 + 5.51252659e+00 2.69842267e+00 + -2.23724309e+00 -4.77624004e-01 + 4.99726228e+00 1.74690949e+00 + 1.75859162e-01 -1.49350995e-01 + 4.13382789e+00 1.31735161e+00 + 2.69058117e+00 4.87656923e-01 + 1.07180318e+00 1.01426954e+00 + 3.37216869e+00 1.05955377e+00 + -2.95006781e+00 -1.57048303e+00 + -2.46401648e+00 -8.37056374e-01 + 1.19012962e-01 7.54702770e-01 + 3.34142539e+00 4.81938295e-01 + 2.92643913e+00 1.04301050e+00 + 2.89697751e+00 1.37551442e+00 + -1.03094242e+00 2.20903962e-01 + -5.13914589e+00 -2.23355387e+00 + -8.81680780e-01 1.83590000e-01 + 2.82334775e+00 1.26650464e+00 + -2.81042540e-01 -3.26370240e-01 + 2.97995487e+00 8.34569452e-01 + -1.39857135e+00 -1.15798385e+00 + 4.27186506e+00 9.04253702e-01 + 6.98684517e-01 7.91167305e-01 + 3.52233095e+00 1.29976473e+00 + 2.21448029e+00 2.73213379e-01 + -3.13505683e-01 -1.20593774e-01 + 3.70571571e+00 1.06220876e+00 + 9.83881041e-01 5.67713803e-01 + -2.17897705e+00 2.52925205e-01 + 1.38734039e+00 4.61287066e-01 + -1.41181602e+00 -1.67248955e-02 + -1.69974639e+00 -7.17812071e-01 + -2.01005793e-01 -7.49662056e-01 + 1.69016336e+00 3.24687979e-01 + -2.03250179e+00 -2.76108460e-01 + 3.68776848e-01 4.12536941e-01 + 7.66238259e-01 -1.84750637e-01 + -2.73989147e-01 -1.72817250e-01 + -2.18623745e+00 -2.10906798e-01 + -1.39795625e-01 3.26066094e-02 + -2.73826912e-01 -6.67586097e-02 + -1.57880654e+00 -4.99395900e-01 + 4.55950908e+00 2.29410489e+00 + -7.36479631e-01 -1.57861857e-01 + 1.92082888e+00 1.05843391e+00 + 4.29192810e+00 1.38127810e+00 + 1.61852879e+00 1.95871986e-01 + -1.95027403e+00 -5.22448168e-01 + -1.67446281e+00 -9.41497162e-01 + 6.07097859e-01 3.44178029e-01 + -3.44004683e+00 -1.49258461e+00 + 2.72114752e+00 6.00728991e-01 + 8.80685522e-01 -2.53243336e-01 + 1.39254928e+00 3.42988512e-01 + 1.14194836e-01 -8.57945694e-02 + -1.49387332e+00 -7.60860481e-01 + -1.98053285e+00 -4.86039865e-01 + 3.56008568e+00 1.08438692e+00 + 2.27833961e-01 1.09441881e+00 + -1.16716710e+00 -6.54778242e-01 + 2.02156613e+00 5.42075758e-01 + 1.08429178e+00 -7.67420693e-01 + 6.63058455e-01 4.61680991e-01 + -1.06201537e+00 1.38862846e-01 + 3.08701875e+00 8.32580273e-01 + -4.96558108e-01 -2.47031257e-01 + 7.95109987e-01 7.59314147e-02 + -3.39903524e-01 8.71565566e-03 + 8.68351357e-01 4.78358641e-01 + 1.48750819e+00 7.63257420e-01 + -4.51224101e-01 -4.44056898e-01 + -3.02734750e-01 -2.98487961e-01 + 5.46846609e-01 7.02377629e-01 + 1.65129778e+00 3.74008231e-01 + -7.43336512e-01 3.95723531e-01 + -5.88446605e-01 -6.47520211e-01 + 3.58613167e+00 1.95024937e+00 + 3.11718883e+00 8.37984715e-01 + 1.80919244e+00 9.62644986e-01 + 5.43856371e-02 -5.86297543e-01 + -1.95186766e+00 -1.02624212e-01 + 8.95628057e-01 5.91812281e-01 + 4.97691627e-02 5.31137156e-01 + -1.07633113e+00 -2.47392788e-01 + -1.17257986e+00 -8.68528265e-01 + -8.19227665e-02 5.80579434e-03 + -2.86409787e-01 1.95812924e-01 + 1.10582671e+00 7.42853240e-01 + 4.06429774e+00 1.06557476e+00 + -3.42521792e+00 -7.74327139e-01 + 1.28468671e+00 6.20431661e-01 + 6.01201008e-01 -1.16799728e-01 + -1.85058727e-01 -3.76235293e-01 + 5.44083324e+00 2.98490868e+00 + 2.69273070e+00 7.83901153e-01 + 1.88938036e-01 -4.83222152e-01 + 1.05667256e+00 -2.57003165e-01 + 2.99711662e-01 -4.33131912e-01 + 7.73689216e-02 -1.78738364e-01 + 9.58326279e-01 6.38325706e-01 + -3.97727049e-01 2.27314759e-01 + 3.36098175e+00 1.12165237e+00 + 1.77804871e+00 6.46961933e-01 + -2.86945546e+00 -1.00395518e+00 + 3.03494815e+00 7.51814612e-01 + -1.43658194e+00 -3.55432244e-01 + -3.08455105e+00 -1.51535106e+00 + -1.55841975e+00 3.93454820e-02 + 7.96073412e-01 -3.11036969e-01 + -9.84125401e-01 -1.02064649e+00 + -7.75688143e+00 -3.65219926e+00 + 1.53816429e+00 7.65926670e-01 + -4.92712738e-01 2.32244240e-02 + -1.93166919e+00 -1.07701304e+00 + 2.03029875e-02 -7.54055699e-01 + 2.52177489e+00 1.01544979e+00 + 3.65109048e-01 -9.48328494e-01 + -1.28849143e-01 2.51947174e-01 + -1.02428075e+00 -9.37767116e-01 + -3.04179748e+00 -9.97926994e-01 + -2.51986980e+00 -1.69117413e+00 + -1.24900838e+00 -4.16179917e-01 + 2.77943992e+00 1.22842327e+00 + -4.37434557e+00 -1.70182693e+00 + -1.60019319e+00 -4.18345639e-01 + -1.67613646e+00 -9.44087262e-01 + -9.00843245e-01 8.26378089e-02 + 3.29770621e-01 -9.07870444e-01 + -2.84650535e+00 -9.00155396e-01 + 1.57111705e+00 7.07432268e-01 + 1.24948552e+00 1.04812849e-01 + 1.81440558e+00 9.53545082e-01 + -1.74915794e+00 -1.04606288e+00 + 1.20593269e+00 -1.12607147e-02 + 1.36004919e-01 -1.09828044e+00 + 2.57480693e-01 3.34941541e-01 + 7.78775385e-01 -5.32494732e-01 + -1.79155126e+00 -6.29994129e-01 + -1.75706839e+00 -8.35100126e-01 + 4.29512012e-01 7.81426910e-02 + 3.08349370e-01 -1.27359861e-01 + 1.05560329e+00 4.55150640e-01 + 1.95662574e+00 1.17593217e+00 + 8.77376632e-01 6.57866662e-01 + 7.71311255e-01 9.15134334e-02 + -6.36978275e+00 -2.55874241e+00 + -2.98335339e+00 -1.59567024e+00 + -3.67104587e-01 1.85315291e-01 + 1.95347407e+00 -7.15503113e-02 + 8.45556363e-01 6.51256415e-02 + 9.42868521e-01 3.56647624e-01 + 2.99321875e+00 1.07505254e+00 + -2.91030538e-01 -3.77637183e-01 + 1.62870918e+00 3.37563671e-01 + 2.05773173e-01 3.43337416e-01 + -8.40879199e-01 -1.35600767e-01 + 1.38101624e+00 5.99253495e-01 + -6.93715607e+00 -2.63580662e+00 + -1.04423404e+00 -8.32865050e-01 + 1.33448476e+00 1.04863475e+00 + 6.01675207e-01 1.98585194e-01 + 2.31233993e+00 7.98628331e-01 + 1.85201313e-01 -1.76070247e+00 + 1.92006354e+00 8.45737582e-01 + 1.06320415e+00 2.93426068e-01 + -1.20360141e+00 -1.00301288e+00 + 1.95926629e+00 6.26643532e-01 + 6.04483978e-02 5.72643059e-01 + -1.04568563e+00 -5.91021496e-01 + 2.62300678e+00 9.50997831e-01 + -4.04610275e-01 3.73150879e-01 + 2.26371902e+00 8.73627529e-01 + 2.12545313e+00 7.90640352e-01 + 7.72181917e-03 1.65718952e-02 + 1.00422340e-01 -2.05562936e-01 + -1.22989802e+00 -1.01841681e-01 + 3.09064082e+00 1.04288010e+00 + 5.18274167e+00 1.34749259e+00 + -8.32075153e-01 -1.97592029e-01 + 3.84126764e-02 5.58171345e-01 + 4.99560727e-01 -4.26154438e-02 + 4.79071151e+00 2.19728942e+00 + -2.78437968e+00 -1.17812590e+00 + -2.22804226e+00 -4.31174255e-01 + 8.50762292e-01 -1.06445261e-01 + 1.10812830e+00 -2.59118812e-01 + -2.91450155e-01 6.42802679e-01 + -1.38631532e-01 -5.88585623e-01 + -5.04120983e-01 -2.17094915e-01 + 3.41410820e+00 1.67897767e+00 + -2.23697326e+00 -6.62735244e-01 + -3.55961064e-01 -1.27647226e-01 + -3.55568274e+00 -2.49011369e+00 + -8.77586408e-01 -9.38268065e-03 + 1.52382384e-01 -5.62155760e-01 + 1.55885574e-01 1.07617069e-01 + -8.37129973e-01 -5.22259081e-01 + -2.92741750e+00 -1.35049428e+00 + -3.54670781e-01 5.69205952e-02 + 2.21030255e+00 1.34689986e+00 + 1.60787722e+00 5.75984706e-01 + 1.32294221e+00 5.31577509e-01 + 7.05672928e-01 3.34241244e-01 + 1.41406179e+00 1.15783408e+00 + -6.92172228e-01 -2.84817896e-01 + 3.28358655e-01 -2.66910083e-01 + 1.68013644e-01 -4.28016549e-02 + 2.07365974e+00 7.76496211e-01 + -3.92974907e-01 2.46796730e-01 + -5.76078636e-01 3.25676963e-01 + -1.82547204e-01 -5.06410543e-01 + 3.04754906e+00 1.16174496e+00 + -3.01090632e+00 -1.09195183e+00 + -1.44659696e+00 -6.87838682e-01 + 2.11395861e+00 9.10495785e-01 + 1.40962871e+00 1.13568678e+00 + -1.66653234e-01 -2.10012503e-01 + 3.17456029e+00 9.74502922e-01 + 2.15944820e+00 8.62807189e-01 + -3.45418719e+00 -1.33647548e+00 + -3.41357732e+00 -8.47048920e-01 + -3.06702448e-01 -6.64280634e-01 + -2.86930714e-01 -1.35268264e-01 + -3.15835557e+00 -5.43439253e-01 + 2.49541440e-01 -4.71733570e-01 + 2.71933912e+00 4.13308399e-01 + -2.43787038e+00 -1.08050547e+00 + -4.90234490e-01 -6.64069865e-01 + 8.99524451e-02 5.76180541e-01 + 5.00500404e+00 2.12125521e+00 + -1.73107940e-01 -2.28506575e-02 + 5.44938858e-01 -1.29523352e-01 + 5.13526842e+00 1.68785993e+00 + 1.70228304e+00 1.02601138e+00 + 3.58957507e+00 1.54396196e+00 + 1.85615738e+00 4.92916197e-01 + 2.55772147e+00 7.88438908e-01 + -1.57008279e+00 -4.17377300e-01 + -1.42548604e+00 -3.63684860e-01 + -8.52026118e-01 2.72052686e-01 + -5.10563077e+00 -2.35665994e+00 + -2.95517031e+00 -1.84945297e+00 + -2.91947959e+00 -1.66016784e+00 + -4.21462387e+00 -1.41131535e+00 + 6.59901121e-01 4.87156314e-01 + -9.75352532e-01 -4.50231285e-01 + -5.94084444e-01 -1.16922670e+00 + 7.50554615e-01 -9.83692552e-01 + 1.07054926e+00 2.77143030e-01 + -3.88079578e-01 -4.17737309e-02 + -9.59373733e-01 -8.85454886e-01 + -7.53560665e-02 -5.16223870e-02 + 9.84108158e-01 -5.89290700e-02 + 1.87272961e-01 -4.34238391e-01 + 6.86509981e-01 -3.15116460e-01 + -1.07762538e+00 6.58984161e-02 + 6.09266592e-01 6.91808473e-02 + -8.30529954e-01 -7.00454791e-01 + -9.13179464e-01 -6.31712891e-01 + 7.68744851e-01 1.09840676e+00 + -1.07606690e+00 -8.78390282e-01 + -1.71038184e+00 -5.73606033e-01 + 8.75982765e-01 3.66343143e-01 + -7.04919009e-01 -8.49182590e-01 + -1.00274668e+00 -7.99573611e-01 + -1.05562848e+00 -5.84060076e-01 + 4.03490015e+00 1.28679206e+00 + -3.53484804e+00 -1.71381255e+00 + 2.31527363e-01 1.04179397e-01 + -3.58592392e-02 3.74895739e-01 + 3.92253428e+00 1.81852726e+00 + -7.27384249e-01 -6.45605128e-01 + 4.65678097e+00 2.41379899e+00 + 1.16750534e+00 7.60718205e-01 + 1.15677059e+00 7.96225550e-01 + -1.42920261e+00 -4.66946295e-01 + 3.71148192e+00 1.88060191e+00 + 2.44052407e+00 3.84472199e-01 + -1.64535035e+00 -8.94530036e-01 + -3.69608753e+00 -1.36402754e+00 + 2.24419208e+00 9.69744889e-01 + 2.54822427e+00 1.22613039e+00 + 3.77484909e-01 -5.98521878e-01 + -3.61521175e+00 -1.11123912e+00 + 3.28113127e+00 1.52551775e+00 + -3.51030902e+00 -1.53913980e+00 + -2.44874505e+00 -6.30246005e-01 + -3.42516153e-01 -5.07352665e-01 + 1.09110502e+00 6.36821628e-01 + -2.49434967e+00 -8.02827146e-01 + 1.41763139e+00 -3.46591820e-01 + 1.61108619e+00 5.93871102e-01 + 3.97371717e+00 1.35552499e+00 + -1.33437177e+00 -2.83908670e-01 + -1.41606483e+00 -1.76402601e-01 + 2.23945322e-01 -1.77157065e-01 + 2.60271569e+00 2.40778251e-01 + -2.82213895e-02 1.98255474e-01 + 4.20727940e+00 1.31490863e+00 + 3.36944889e+00 1.57566635e+00 + 3.53049396e+00 1.73579350e+00 + -1.29170202e+00 -1.64196290e+00 + 9.27295604e-01 9.98808036e-01 + 1.75321843e-01 -2.83267817e-01 + -2.19069578e+00 -1.12814358e+00 + 1.66606031e+00 7.68006933e-01 + -7.13826035e-01 5.20881684e-02 + -3.43821888e+00 -2.36137021e+00 + -5.93210310e-01 1.21843813e-01 + -4.09800822e+00 -1.39893953e+00 + 2.74110954e+00 1.52728606e+00 + 1.72652512e+00 -1.25435113e-01 + 1.97722357e+00 6.40667481e-01 + 4.18635780e-01 3.57018509e-01 + -1.78303569e+00 -2.11864764e-01 + -3.52809366e+00 -2.58794450e-01 + -4.72407090e+00 -1.63870734e+00 + 1.73917807e+00 8.73251829e-01 + 4.37979356e-01 8.49210569e-01 + 3.93791881e+00 1.76269490e+00 + 2.79065411e+00 1.04019042e+00 + -8.47426142e-01 -3.40136892e-01 + -4.24389181e+00 -1.80253120e+00 + -1.86675870e+00 -7.64558265e-01 + 9.46212675e-01 -7.77681445e-02 + -2.82448462e+00 -1.33592449e+00 + -2.57938567e+00 -1.56554690e+00 + -2.71615767e+00 -6.27667233e-01 + -1.55999166e+00 -5.81013466e-01 + -4.24696864e-01 -7.44673250e-01 + 1.67592970e+00 7.68164292e-01 + 8.48455216e-01 -6.05681126e-01 + 6.12575454e+00 1.65607584e+00 + 1.38207327e+00 2.39261863e-01 + 3.13364450e+00 1.17154698e+00 + 1.71694858e+00 1.26744905e+00 + -1.61746367e+00 -8.80098073e-01 + -8.52196756e-01 -9.27299728e-01 + -1.51562462e-01 -8.36552490e-02 + -7.04792753e-01 -1.24726713e-02 + -3.35265757e+00 -1.82176312e+00 + 3.32173170e-01 -1.33405580e-01 + 4.95841013e-01 4.58292712e-01 + 1.57713955e+00 7.79272991e-01 + 2.09743109e+00 9.23542557e-01 + 3.90450311e-03 -8.42873164e-01 + 2.59519038e+00 7.56479591e-01 + -5.77643976e-01 -2.36401904e-01 + -5.22310654e-01 1.34187830e-01 + -2.22096086e+00 -7.75507719e-01 + 1.35907831e+00 7.80197510e-01 + 3.80355868e+00 1.16983476e+00 + 3.82746596e+00 1.31417718e+00 + 3.30451183e+00 1.55398159e+00 + -3.42917814e-01 -8.62281222e-02 + -2.59093020e+00 -9.29883526e-01 + 1.40928562e+00 1.08398346e+00 + 1.54400137e-01 3.35881092e-01 + 1.59171586e+00 1.18855802e+00 + -5.25164002e-01 -1.03104220e-01 + 2.20067959e+00 1.37074713e+00 + 6.97860830e-01 6.27718548e-01 + -4.59743507e-01 1.36061163e-01 + -1.04691963e-01 -2.16271727e-01 + -1.08905573e+00 -5.95510769e-01 + -1.00826983e+00 -5.38509162e-02 + -3.16402719e+00 -1.33414216e+00 + 1.47870874e-01 1.75234619e-01 + -2.57078234e-01 7.03316889e-02 + 1.81073945e+00 4.26901462e-01 + 2.65476530e+00 6.74217273e-01 + 1.27539811e+00 6.22914081e-01 + -3.76750499e-01 -1.20629449e+00 + 1.00177595e+00 -1.40660091e-01 + -2.98919265e+00 -1.65145013e+00 + -2.21557682e+00 -8.11123452e-01 + -3.22635378e+00 -1.65639056e+00 + -2.72868553e+00 -1.02812087e+00 + 1.26042797e+00 8.49005248e-01 + -9.38318534e-01 -9.87588651e-01 + 3.38013194e-01 -1.00237461e-01 + 1.91175691e+00 8.48716369e-01 + 4.30244344e-01 6.05539915e-02 + 2.21783435e+00 3.03268204e-01 + 1.78019576e+00 1.27377108e+00 + 1.59733274e+00 4.40674687e-02 + 3.97428484e+00 2.20881566e+00 + -2.41108677e+00 -6.01410418e-01 + -2.50796499e+00 -5.71169866e-01 + -3.71957427e+00 -1.38195726e+00 + -1.57992670e+00 1.32068593e-01 + -1.35278851e+00 -6.39349270e-01 + 1.23075932e+00 2.40445409e-01 + 1.35606530e+00 4.33180078e-01 + 9.60968518e-02 2.26734255e-01 + 6.22975063e-01 5.03431915e-02 + -1.47624851e+00 -3.60568238e-01 + -2.49337808e+00 -1.15083052e+00 + 2.15717792e+00 1.03071559e+00 + -3.07814376e-02 1.38700314e-02 + 4.52049499e-02 -4.86409775e-01 + 2.58231061e+00 1.14327809e-01 + 1.10999138e+00 -5.18568405e-01 + -2.19426443e-01 -5.37505538e-01 + -4.44740298e-01 6.78099955e-01 + 4.03379080e+00 1.49825720e+00 + -5.13182408e-01 -4.90201950e-01 + -6.90139716e-01 1.63875126e-01 + -8.17281461e-01 2.32155064e-01 + -2.92357619e-01 -8.02573544e-01 + -1.80769841e+00 -7.58907326e-01 + 2.16981590e+00 1.06728873e+00 + 1.98995203e-01 -6.84176682e-02 + -2.39546753e+00 -2.92873789e-01 + -4.24251021e+00 -1.46255564e+00 + -5.01411291e-01 -5.95712813e-03 + 2.68085809e+00 1.42883780e+00 + -4.13289873e+00 -1.62729388e+00 + 1.87957843e+00 3.63341638e-01 + -1.15270744e+00 -3.03563774e-01 + -4.43994248e+00 -2.97323905e+00 + -7.17067733e-01 -7.08349542e-01 + -3.28870393e+00 -1.19263863e+00 + -7.55325944e-01 -5.12703329e-01 + -2.07291938e+00 -2.65025085e-01 + -7.50073814e-01 -1.70771041e-01 + -8.77381404e-01 -5.47417325e-01 + -5.33725862e-01 5.15837119e-01 + 8.45056431e-01 2.82125560e-01 + -1.59598637e+00 -1.38743235e+00 + 1.41362902e+00 1.06407789e+00 + 1.02584504e+00 -3.68219466e-01 + -1.04644488e+00 -1.48769392e-01 + 2.66990191e+00 8.57633492e-01 + -1.84251857e+00 -9.82430175e-01 + 9.71404204e-01 -2.81934209e-01 + -2.50177989e+00 -9.21260335e-01 + -1.31060074e+00 -5.84488113e-01 + -2.12129400e-01 -3.06244708e-02 + -5.28933882e+00 -2.50663129e+00 + 1.90220541e+00 1.08662918e+00 + -3.99366086e-02 -6.87178973e-01 + -4.93417342e-01 4.37354182e-01 + 2.13494486e+00 1.37679569e+00 + 2.18396765e+00 5.81023868e-01 + -3.07866587e+00 -1.45384974e+00 + 6.10894119e-01 -4.17050124e-01 + -1.88766952e+00 -8.86160058e-01 + 3.34527253e+00 1.78571260e+00 + 6.87769059e-01 -5.01157336e-01 + 2.60470837e+00 1.45853560e+00 + -6.49315691e-01 -9.16112805e-01 + -1.29817687e+00 -2.15924339e-01 + -1.20100409e-03 -4.03137422e-01 + -1.36471594e+00 -6.93266356e-01 + 1.38682062e+00 7.15131598e-01 + 2.47830103e+00 1.24862305e+00 + -2.78288147e+00 -1.03329235e+00 + -7.33443403e-01 -6.11041652e-01 + -4.12745671e-01 -5.96133390e-02 + -2.58632336e+00 -4.51557058e-01 + -1.16570367e+00 -1.27065510e+00 + 2.76187104e+00 2.21895451e-01 + -3.80443767e+00 -1.66319902e+00 + 9.84658633e-01 6.81475569e-01 + 9.33814584e-01 -4.89335563e-02 + -4.63427997e-01 1.72989539e-01 + 1.82401546e+00 3.60164021e-01 + -5.36521077e-01 -8.08691351e-01 + -1.37367030e+00 -1.02126160e+00 + -3.70310682e+00 -1.19840844e+00 + -1.51894242e+00 -3.89510223e-01 + -3.67347940e-01 -3.25540516e-02 + -1.00988595e+00 1.82802194e-01 + 2.01622795e+00 7.86367901e-01 + 1.02440231e+00 8.79780360e-01 + -3.05971480e+00 -8.40901527e-01 + 2.73909457e+00 1.20558628e+00 + 2.39559056e+00 1.10786694e+00 + 1.65471544e+00 7.33824651e-01 + 2.18546787e+00 6.41168955e-01 + 1.47152266e+00 3.91839132e-01 + 1.45811155e+00 5.21820495e-01 + -4.27531469e-02 -3.52343068e-03 + -9.54948010e-01 -1.52313876e-01 + 7.57151215e-01 -5.68728854e-03 + -8.46205751e-01 -7.54580229e-01 + 4.14493548e+00 1.45532780e+00 + 4.58688968e-01 -4.54012803e-02 + -1.49295381e+00 -4.57471758e-01 + 1.80020351e+00 8.13724973e-01 + -5.82727738e+00 -2.18269581e+00 + -2.09017809e+00 -1.18305177e+00 + -2.31628303e+00 -7.21600235e-01 + -8.09679091e-01 -1.49101752e-01 + 8.88005605e-01 8.57940857e-01 + -1.44148219e+00 -3.10926299e-01 + 3.68828186e-01 -3.08848059e-01 + -6.63267389e-01 -8.58950139e-02 + -1.14702569e+00 -6.32147854e-01 + -1.51741715e+00 -8.53330564e-01 + -1.33903718e+00 -1.45875547e-01 + 4.12485387e+00 1.85620435e+00 + -2.42353639e+00 -2.92669850e-01 + 1.88708583e+00 9.35984730e-01 + 2.15585179e+00 6.30469051e-01 + -1.13627973e-01 -1.62554045e-01 + 2.04540494e+00 1.36599834e+00 + 2.81591381e+00 1.60897941e+00 + 3.02736260e-02 3.83255815e-03 + 7.97634013e-02 -2.82035099e-01 + -3.24607473e-01 -5.30065956e-01 + -3.91862894e+00 -1.94083334e+00 + 1.56360901e+00 7.93882743e-01 + -1.03905772e+00 6.25590229e-01 + 2.54746492e+00 1.64233560e+00 + -4.80774423e-01 -8.92298032e-02 + 9.06979990e-02 1.05020427e+00 + -2.47521290e+00 -1.78275982e-01 + -3.91871729e-01 3.80285423e-01 + 1.00658382e+00 4.58947483e-01 + 4.68102941e-01 1.02992741e+00 + 4.44242568e-01 2.89870239e-01 + 3.29684452e+00 1.44677474e+00 + -2.24983007e+00 -9.65574499e-01 + -3.54453926e-01 -3.99020325e-01 + -3.87429665e+00 -1.90079739e+00 + 2.02656674e+00 1.12444894e+00 + 3.77011621e+00 1.43200852e+00 + 1.61259275e+00 4.65417399e-01 + 2.28725434e+00 6.79181395e-01 + 2.75421009e+00 2.27327345e+00 + -2.40894409e+00 -1.03926359e+00 + 1.52996651e-01 -2.73373046e-02 + -2.63218977e+00 -7.22802821e-01 + 2.77688169e+00 1.15310186e+00 + 1.18832341e+00 4.73457165e-01 + -2.35536326e+00 -1.08034554e+00 + -5.84221627e-01 1.03505984e-02 + 2.96730300e+00 1.33478306e+00 + -8.61947692e-01 6.09137051e-02 + 8.22343921e-01 -8.14155286e-02 + 1.75809015e+00 1.07921470e+00 + 1.19501279e+00 1.05309972e+00 + -1.75901792e+00 9.75320161e-02 + 1.64398635e+00 9.54384323e-01 + -2.21878052e-01 -3.64847144e-01 + -2.03128968e+00 -8.57866419e-01 + 1.86750633e+00 7.08524487e-01 + 8.03972976e-01 3.47404314e-01 + 3.41203749e+00 1.39810900e+00 + 4.22397681e-01 -6.41440488e-01 + -4.88493360e+00 -1.58967816e+00 + -1.67649284e-01 -1.08485915e-01 + 2.11489023e+00 1.50506158e+00 + -1.81639929e+00 -3.85542192e-01 + 2.24044819e-01 -1.45100577e-01 + -3.39262411e+00 -1.44394324e+00 + 1.68706599e+00 2.29199618e-01 + -1.94093257e+00 -1.65975814e-01 + 8.28143367e-01 5.92109281e-01 + -8.29587998e-01 -9.57130831e-01 + -1.50011401e+00 -8.36802092e-01 + 2.40770449e+00 9.32820177e-01 + 7.41391309e-02 3.12878473e-01 + 1.87745264e-01 6.19231425e-01 + 9.57622692e-01 -2.20640033e-01 + 3.18479243e+00 1.02986233e+00 + 2.43133846e+00 8.41302677e-01 + -7.09963834e-01 1.99718943e-01 + -2.88253498e-01 -3.62772094e-01 + 5.14052574e+00 1.79304595e+00 + -3.27930993e+00 -1.29177973e+00 + -1.16723536e+00 1.29519656e-01 + 1.04801056e+00 3.41508300e-01 + -3.99256195e+00 -2.51176471e+00 + -7.62824318e-01 -6.84242153e-01 + 2.71524986e-02 5.35157164e-02 + 3.26430102e+00 1.34887262e+00 + -1.72357766e+00 -4.94524388e-01 + -3.81149536e+00 -1.28121944e+00 + 3.36919354e+00 1.10672075e+00 + -3.14841757e+00 -7.10713767e-01 + -3.16463676e+00 -7.58558435e-01 + -2.44745969e+00 -1.08816514e+00 + 2.79173264e-01 -2.19652051e-02 + 4.15309883e-01 6.07502790e-01 + -9.51007417e-01 -5.83976336e-01 + -1.47929839e+00 -8.39850409e-01 + 2.38335703e+00 6.16055149e-01 + -7.47749031e-01 -5.56164928e-01 + -3.65643622e-01 -5.06684411e-01 + -1.76634163e+00 -7.86382097e-01 + 6.76372222e-01 -3.06592181e-01 + -1.33505058e+00 -1.18301441e-01 + 3.59660179e+00 2.00424178e+00 + -7.88912762e-02 8.71956146e-02 + 1.22656397e+00 1.18149583e+00 + 4.24919729e+00 1.20082355e+00 + 2.94607456e+00 1.00676505e+00 + 7.46061275e-02 4.41761753e-02 + -2.47738025e-02 1.92737701e-01 + -2.20509316e-01 -3.79163193e-01 + -3.50222190e-01 3.58727299e-01 + -3.64788014e+00 -1.36107312e+00 + 3.56062799e+00 9.27032742e-01 + 1.04317289e+00 6.08035970e-01 + 4.06718718e-01 3.00628051e-01 + 4.33158086e+00 2.25860714e+00 + 2.13917145e-01 -1.72757967e-01 + -1.40637998e+00 -1.14119465e+00 + 3.61554872e+00 1.87797348e+00 + 1.01726871e+00 5.70255097e-01 + -7.04902551e-01 2.16444147e-01 + -2.51492186e+00 -8.52997369e-01 + 1.85097530e+00 1.15124496e+00 + -8.67569714e-01 -3.05682432e-01 + 8.07550858e-01 5.88901608e-01 + 1.85186755e-01 -1.94589367e-01 + -1.23378238e+00 -7.84128347e-01 + -1.22713161e+00 -4.21218235e-01 + 2.97751165e-01 2.81055275e-01 + 4.77703554e+00 1.66265524e+00 + 2.51549669e+00 7.49980674e-01 + 2.76510822e-01 1.40456909e-01 + 1.98740905e+00 -1.79608212e-01 + 9.35429145e-01 8.44344180e-01 + -1.20854492e+00 -5.00598453e-01 + 2.29936219e+00 8.10236668e-01 + 6.92555544e-01 -2.65891331e-01 + -1.58050994e+00 2.31237821e-01 + -1.50864880e+00 -9.49661690e-01 + -1.27689206e+00 -7.18260016e-01 + -3.12517127e+00 -1.75587113e+00 + 8.16062912e-02 -6.56551804e-01 + -5.02479939e-01 -4.67162543e-01 + -5.47435788e+00 -2.47799576e+00 + 1.95872901e-02 5.80874076e-01 + -1.59064958e+00 -6.34554756e-01 + -3.77521478e+00 -1.74301790e+00 + 5.89628224e-01 8.55736553e-01 + -1.81903543e+00 -7.50011008e-01 + 1.38557775e+00 3.71490991e-01 + 9.70032652e-01 -7.11356016e-01 + 2.63539625e-01 -4.20994771e-01 + 2.12154222e+00 8.19081400e-01 + -6.56977937e-01 -1.37810098e-01 + 8.91309581e-01 2.77864361e-01 + -7.43693195e-01 -1.46293770e-01 + 2.24447769e+00 4.00911438e-01 + -2.25169262e-01 2.04148801e-02 + 1.68744684e+00 9.47573007e-01 + 2.73086373e-01 3.30877195e-01 + 5.54294414e+00 2.14198009e+00 + -8.49238733e-01 3.65603298e-02 + 2.39685712e+00 1.17951039e+00 + -2.58230528e+00 -5.52116673e-01 + 2.79785277e+00 2.88833717e-01 + -1.96576188e-01 1.11652123e+00 + -4.69383301e-01 1.96496282e-01 + -1.95011845e+00 -6.15235169e-01 + 1.03379890e-02 2.33701239e-01 + 4.18933607e-01 2.77939814e-01 + -1.18473337e+00 -4.10051126e-01 + -7.61499744e-01 -1.43658094e+00 + -1.65586092e+00 -3.41615303e-01 + -5.58523700e-02 -5.21837080e-01 + -2.40331088e+00 -2.64521583e-01 + 2.24925206e+00 6.79843335e-02 + 1.46360479e+00 1.04271443e+00 + -3.09255443e+00 -1.82548953e+00 + 2.11325841e+00 1.14996627e+00 + -8.70657797e-01 1.02461839e-01 + -5.71056521e-01 9.71232588e-02 + -3.37870752e+00 -1.54091877e+00 + 1.03907189e+00 -1.35661392e-01 + 8.40057486e-01 6.12172413e-02 + -1.30998234e+00 -1.34077226e+00 + 7.53744974e-01 1.49447350e-01 + 9.13995056e-01 -1.81227962e-01 + 2.28386229e-01 3.74498520e-01 + 2.54829151e-01 -2.88802704e-01 + 1.61709009e+00 2.09319193e-01 + -1.12579380e+00 -5.95955338e-01 + -2.69610726e+00 -2.76222736e-01 + -2.63773329e+00 -7.84491970e-01 + -2.62167427e+00 -1.54792874e+00 + -4.80639856e-01 -1.30582102e-01 + -1.26130891e+00 -8.86841840e-01 + -1.24951950e+00 -1.18182622e+00 + -1.40107574e+00 -9.13695575e-01 + 4.99872179e-01 4.69014702e-01 + -2.03550193e-02 -1.48859738e-01 + -1.50189069e+00 -2.97714278e-02 + -2.07846113e+00 -7.29937809e-01 + -5.50576792e-01 -7.03151525e-01 + -3.88069238e+00 -1.63215295e+00 + 2.97032988e+00 6.43571144e-01 + -1.85999273e-01 1.18107620e+00 + 1.79249709e+00 6.65356160e-01 + 2.68842472e+00 1.35703255e+00 + 1.07675417e+00 1.39845588e-01 + 8.01226349e-01 2.11392275e-01 + 9.64329379e-01 3.96146195e-01 + -8.22529511e-01 1.96080831e-01 + 1.92481841e+00 4.62985744e-01 + 3.69756927e-01 3.77135799e-01 + 1.19807835e+00 8.87715050e-01 + -1.01363587e+00 -2.48151636e-01 + 8.53071010e-01 4.96887868e-01 + -3.41120553e+00 -1.35401843e+00 + -2.64787381e+00 -1.08690563e+00 + -1.11416759e+00 -4.43848915e-01 + 1.46242648e+00 6.17106076e-02 + -7.52968881e-01 -9.20972209e-01 + -1.22492228e+00 -5.40327617e-01 + 1.08001827e+00 5.29593785e-01 + -2.58706464e-01 1.13022085e-01 + -4.27394011e-01 1.17864354e-02 + -3.20728413e+00 -1.71224737e-01 + 1.71398530e+00 8.68885893e-01 + 2.12067866e+00 1.45092772e+00 + 4.32782616e-01 -3.34117769e-01 + 7.80084374e-01 -1.35100217e-01 + -2.05547729e+00 -4.70217750e-01 + 2.38379736e+00 1.09186058e+00 + -2.80825477e+00 -1.03320187e+00 + 2.63434576e+00 1.15671733e+00 + -1.60936214e+00 1.91843035e-01 + -5.02298769e+00 -2.32820708e+00 + 1.90349195e+00 1.45215416e+00 + 3.00232888e-01 3.24412586e-01 + -2.46503943e+00 -1.19550010e+00 + 1.06304233e+00 2.20136246e-01 + -2.99101388e+00 -1.58299318e+00 + 2.30071719e+00 1.12881362e+00 + -2.37587247e+00 -8.08298336e-01 + 7.27006308e-01 3.80828984e-01 + 2.61199061e+00 1.56473491e+00 + 8.33936357e-01 -1.42189425e-01 + 3.13291605e+00 1.77771210e+00 + 2.21917371e+00 5.68427075e-01 + 2.38867649e+00 9.06637262e-01 + -6.92959466e+00 -3.57682881e+00 + 2.57904824e+00 5.93959108e-01 + 2.71452670e+00 1.34436199e+00 + 4.39988761e+00 2.13124672e+00 + 5.71783077e-01 5.08346173e-01 + -3.65399429e+00 -1.18192861e+00 + 4.46176453e-01 3.75685594e-02 + -2.97501495e+00 -1.69459236e+00 + 1.60855728e+00 9.20930014e-01 + -1.44270290e+00 -1.93922306e-01 + 1.67624229e+00 1.66233866e+00 + -1.42579598e+00 -1.44990145e-01 + 1.19923176e+00 4.58490278e-01 + -9.00068460e-01 5.09701825e-02 + -1.69391694e+00 -7.60070300e-01 + -1.36576440e+00 -5.24244256e-01 + -1.03016748e+00 -3.44625878e-01 + 2.40519313e+00 1.09947587e+00 + 1.50365433e+00 1.06464802e+00 + -1.07609727e+00 -3.68897187e-01 + 2.44969069e+00 1.28486192e+00 + -1.25610307e+00 -1.14644789e+00 + 2.05962899e+00 4.31162369e-01 + -7.15886908e-01 -6.11587804e-02 + -6.92354119e-01 -7.85019920e-01 + -1.63016508e+00 -5.96944975e-01 + 1.90352536e+00 1.28197457e+00 + -4.01535243e+00 -1.81934488e+00 + -1.07534435e+00 -2.10544784e-01 + 3.25500866e-01 7.69603661e-01 + 2.18443365e+00 6.59773335e-01 + 8.80856790e-01 6.39505913e-01 + -2.23956372e-01 -4.65940132e-01 + -1.06766519e+00 -5.38388505e-03 + 7.25556863e-01 -2.91123488e-01 + -4.69451411e-01 7.89182650e-02 + 2.58146587e+00 1.29653243e+00 + 1.53747468e-01 7.69239075e-01 + -4.61152262e-01 -4.04151413e-01 + 1.48183517e+00 8.10079506e-01 + -1.83402614e+00 -1.36939322e+00 + 1.49315501e+00 7.95225425e-01 + 1.41922346e+00 1.05582774e-01 + 1.57473493e-01 9.70795657e-01 + -2.67603254e+00 -7.48562280e-01 + -8.49156216e-01 -6.05762529e-03 + 1.12944274e+00 3.67741591e-01 + 1.94228071e-01 5.28188141e-01 + -3.65610158e-01 4.05851838e-01 + -1.98839111e+00 -1.38452764e+00 + 2.73765752e+00 8.24150530e-01 + 7.63728641e-01 3.51617707e-01 + 5.78307267e+00 1.68103612e+00 + 2.27547227e+00 3.60876164e-01 + -3.50681697e+00 -1.74429984e+00 + 4.01241184e+00 1.26227829e+00 + 2.44946343e+00 9.06119057e-01 + -2.96638941e+00 -9.01532322e-01 + 1.11267643e+00 -3.43333381e-01 + -6.61868994e-01 -3.44666391e-01 + -8.34917179e-01 5.69478372e-01 + -1.91888454e+00 -3.03791075e-01 + 1.50397636e+00 8.31961240e-01 + 6.12260198e+00 2.16851807e+00 + 1.34093127e+00 8.86649385e-01 + 1.48748519e+00 8.26273697e-01 + 7.62243068e-01 2.64841396e-01 + -2.17604986e+00 -3.54219958e-01 + 2.64708640e-01 -4.38136718e-02 + 1.44725372e+00 1.18499914e-01 + -6.71259446e-01 -1.19526851e-01 + 2.40134595e-01 -8.90042323e-02 + -3.57238199e+00 -1.23166201e+00 + -3.77626645e+00 -1.19533443e+00 + -3.81101035e-01 -4.94160532e-01 + -3.02758757e+00 -1.18436066e+00 + 2.59116298e-01 1.38023047e+00 + 4.17900116e+00 1.12065959e+00 + 1.54598848e+00 2.89806755e-01 + 1.00656475e+00 1.76974511e-01 + -4.15730234e-01 -6.22681694e-01 + -6.00903565e-01 -1.43256959e-01 + -6.03652508e-01 -5.09936379e-01 + -1.94096658e+00 -9.48789544e-01 + -1.74464105e+00 -8.50491590e-01 + 1.17652544e+00 1.88118317e+00 + 2.35507776e+00 1.44000205e+00 + 2.63067924e+00 1.06692988e+00 + 2.88805386e+00 1.23924715e+00 + 8.27595008e-01 5.75364692e-01 + 3.91384216e-01 9.72781920e-02 + -1.03866816e+00 -1.37567768e+00 + -1.34777969e+00 -8.40266025e-02 + -4.12904508e+00 -1.67618340e+00 + 1.27918111e+00 3.52085961e-01 + 4.15361174e-01 6.28896189e-01 + -7.00539496e-01 4.80447955e-02 + -1.62332639e+00 -5.98236485e-01 + 1.45957300e+00 1.00305154e+00 + -3.06875603e+00 -1.25897545e+00 + -1.94708176e+00 4.85143006e-01 + 3.55744156e+00 -1.07468822e+00 + 1.21602223e+00 1.28768827e-01 + 1.89093098e+00 -4.70835659e-01 + -6.55759125e+00 2.70114082e+00 + 8.96843535e-01 -3.98115252e-01 + 4.13450429e+00 -2.32069236e+00 + 2.37764218e+00 -1.09098890e+00 + -1.11388901e+00 6.27083097e-01 + -6.34116929e-01 4.62816387e-01 + 2.90203079e+00 -1.33589143e+00 + 3.17457598e+00 -5.13575945e-01 + -1.76362299e+00 5.71820693e-01 + 1.66103362e+00 -8.99466249e-01 + -2.53947433e+00 8.40084780e-01 + 4.36631397e-01 7.24234261e-02 + -1.87589394e+00 5.08529113e-01 + 4.49563965e+00 -9.43365992e-01 + 1.78876299e+00 -1.27076149e+00 + -1.16269107e-01 -4.55078316e-01 + 1.92966079e+00 -8.05371385e-01 + 2.20632583e+00 -9.00919345e-01 + 1.52387824e+00 -4.82391996e-01 + 8.04004564e-01 -2.73650595e-01 + -7.75326067e-01 1.07469566e+00 + 1.83226282e+00 -4.52173344e-01 + 1.25079758e-01 -3.52895417e-02 + -9.90957437e-01 8.55993130e-01 + 1.71623322e+00 -7.08691667e-01 + -2.86175924e+00 6.75160955e-01 + -8.40817853e-01 -1.00361809e-01 + 1.33393000e+00 -4.65788123e-01 + 5.29394114e-01 -5.44881619e-02 + -8.07435599e-01 8.27353370e-01 + -4.33165824e+00 1.97299638e+00 + 1.26452422e+00 -8.34070486e-01 + 1.45996394e-02 2.97736043e-01 + -1.64489287e+00 6.72839598e-01 + -5.74234578e+00 3.20975117e+00 + 2.13841341e-02 3.64514015e-01 + 6.68084924e+00 -2.27464254e+00 + -3.22881590e+00 8.01879324e-01 + 3.02534313e-01 -4.56222796e-01 + -5.84520734e+00 1.95678162e+00 + 2.81515232e+00 -1.72101318e+00 + -2.39620908e-01 2.69145522e-01 + -7.41669691e-01 -2.30283281e-01 + -2.15682714e+00 3.45313021e-01 + 1.23475788e+00 -7.32276553e-01 + -1.71816113e-01 1.20419560e-02 + 1.89174235e+00 2.27435901e-01 + -3.64511114e-01 1.72260361e-02 + -3.24143860e+00 6.50125817e-01 + -2.25707409e+00 5.66970751e-01 + 1.03901456e+00 -1.00588433e+00 + -5.09159710e+00 1.58736109e+00 + 1.45534075e+00 -5.83787452e-01 + 4.28879587e+00 -1.58006866e+00 + 8.52384427e-01 -1.11042299e+00 + 4.51431615e+00 -2.63844265e+00 + -4.33042648e+00 1.86497078e+00 + -2.13568046e+00 5.82559743e-01 + -4.42568887e+00 1.26131214e+00 + 3.15821315e+00 -1.61515905e+00 + -3.14125204e+00 8.49604386e-01 + 6.54152300e-01 -2.04624711e-01 + -3.73374317e-01 9.94187820e-02 + -3.96177282e+00 1.27245623e+00 + 9.59825199e-01 -1.15547861e+00 + 3.56902055e+00 -1.46591091e+00 + 1.55433633e-02 6.93544345e-01 + 1.15684646e+00 -4.99836352e-01 + 3.11824573e+00 -4.75900506e-01 + -8.61706369e-01 -3.50774059e-01 + 9.89057391e-01 -7.16878802e-01 + -4.94787870e+00 2.09137481e+00 + 1.37777347e+00 -1.34946349e+00 + -1.13161577e+00 8.05114754e-01 + 8.12020675e-01 -1.04849421e+00 + 4.73783881e+00 -2.26718812e+00 + 8.99579366e-01 -8.89764451e-02 + 4.78524868e+00 -2.25795843e+00 + 1.75164590e+00 -1.73822209e-01 + 1.30204590e+00 -7.26724717e-01 + -7.26526403e-01 -5.23925361e-02 + 2.01255351e+00 -1.69965366e+00 + 9.87852740e-01 -4.63577220e-01 + 2.45957762e+00 -1.29278962e+00 + -3.13817948e+00 1.64433038e+00 + -1.76302159e+00 9.62784302e-01 + -1.91106331e+00 5.81460008e-01 + -3.30883001e+00 1.30378978e+00 + 5.54376450e-01 3.78814272e-01 + 1.09982111e+00 -1.47969612e+00 + -2.61300705e-02 -1.42573464e-01 + -2.22096157e+00 7.75684440e-01 + 1.70319323e+00 -2.89738444e-01 + -1.43223842e+00 6.39284281e-01 + 2.34360959e-01 -1.64379268e-01 + -2.67147991e+00 9.46548086e-01 + 1.51131425e+00 -4.91594395e-01 + -2.48446856e+00 1.01286123e+00 + 1.50534658e-01 -2.94620246e-01 + -1.66966792e+00 1.67755508e+00 + -1.50094241e+00 3.30163095e-01 + 2.27681194e+00 -1.08064317e+00 + 2.05122965e+00 -1.15165939e+00 + -4.23509309e-01 -6.56906167e-02 + 1.80084023e+00 -1.07228556e+00 + -2.65769521e+00 1.18023206e+00 + 2.02852676e+00 -8.06793574e-02 + -4.49544185e+00 2.68200163e+00 + -7.50043216e-01 1.17079331e+00 + 6.80060893e-02 3.99055351e-01 + -3.83634635e+00 1.38406887e+00 + 3.24858545e-01 -9.25273218e-02 + -2.19895100e+00 1.47819500e+00 + -3.61569522e-01 -1.03188739e-01 + 1.12180375e-01 -9.52696354e-02 + -1.31477803e+00 1.79900570e-01 + 2.39573628e+00 -6.09739269e-01 + -1.00135700e+00 6.02837296e-01 + -4.11994589e+00 2.49599192e+00 + -1.54196236e-01 -4.84921951e-01 + 5.92569908e-01 -1.87310359e-01 + 3.85407741e+00 -1.50979925e+00 + 5.17802528e+00 -2.26032607e+00 + -1.37018916e+00 1.87111822e-01 + 8.46682996e-01 -3.56676331e-01 + -1.17559949e+00 5.29057734e-02 + -5.56475671e-02 6.79049243e-02 + 1.07851745e+00 -5.14535101e-01 + -2.71622446e+00 1.00151846e+00 + -1.08477208e+00 8.81391054e-01 + 5.50755824e-01 -5.20577727e-02 + 4.70885495e+00 -2.04220397e+00 + -1.87375336e-01 -6.16962830e-02 + 3.52097100e-01 2.21163550e-01 + 7.07929984e-01 -1.75827590e-01 + -1.22149219e+00 1.83084346e-01 + 2.58247412e+00 -6.15914898e-01 + -6.01206182e-01 -2.29832987e-01 + 9.83360449e-01 -3.75870060e-01 + -3.20027685e+00 1.35467480e+00 + 1.79178978e+00 -1.38531981e+00 + -3.30376867e-01 -1.16250192e-01 + -1.89053055e+00 5.68463567e-01 + -4.20604849e+00 1.65429681e+00 + -1.01185529e+00 1.92801240e-01 + -6.18819882e-01 5.42206996e-01 + -5.08091672e+00 2.61598591e+00 + -2.62570344e+00 2.51590658e+00 + 3.05577906e+00 -1.49090609e+00 + 2.77609677e+00 -1.37681378e+00 + -7.93515301e-02 4.28072744e-01 + -2.08359471e+00 8.94334295e-01 + 2.20163801e+00 4.01127167e-02 + -1.18145785e-01 -2.06822464e-01 + -2.74788298e-01 2.96250607e-01 + 1.59613555e+00 -3.87246203e-01 + -3.82971472e-01 -3.39716093e-02 + -4.20311307e-02 3.88529510e-01 + 1.52128574e+00 -9.33138876e-01 + -9.06584458e-01 -2.75016094e-02 + 3.56216834e+00 -9.99384622e-01 + 2.11964220e+00 -9.98749118e-02 + 4.01203480e+00 -2.03032745e+00 + -1.24171557e+00 1.97596725e-01 + -1.57230455e+00 4.14126609e-01 + -1.85484741e+00 5.40041563e-01 + 1.76329831e+00 -6.95967734e-01 + -2.29439232e-01 5.08669245e-01 + -5.45124276e+00 2.26907549e+00 + -5.71364288e-02 5.04476476e-01 + 3.12468018e+00 -1.46358879e+00 + 8.20017359e-01 6.51949028e-01 + -1.33977500e+00 2.83634232e-04 + -1.83311685e+00 1.23947117e+00 + 6.31205922e-01 1.19792164e-02 + -2.21967834e+00 6.94056232e-01 + -1.41693842e+00 9.93526233e-01 + -7.58885703e-01 6.78547347e-01 + 3.60239086e+00 -1.08644935e+00 + 6.72217073e-02 3.00036011e-02 + -3.42680958e-01 -3.48049352e-01 + 1.87546079e+00 -4.78018246e-01 + 7.00485821e-01 -3.52905383e-01 + -8.54580948e-01 8.17330861e-01 + 8.19123706e-01 -5.73927281e-01 + 2.70855639e-01 -3.08940052e-01 + -1.05059952e+00 3.27873168e-01 + 1.08282999e+00 4.84559349e-02 + -7.89899220e-01 1.22291138e+00 + -2.87939816e+00 7.17403497e-01 + -2.08429452e+00 8.87409226e-01 + 1.58409232e+00 -4.74123532e-01 + 1.26882735e+00 1.59162510e-01 + -2.53782993e+00 6.18253491e-01 + -8.92757445e-01 3.35979011e-01 + 1.31867900e+00 -1.17355054e+00 + 1.14918879e-01 -5.35184038e-01 + -1.70288738e-01 5.35868087e-02 + 4.21355121e-01 5.41848690e-02 + 2.07926943e+00 -5.72538144e-01 + 4.08788970e-01 3.77655777e-01 + -3.39631381e+00 9.84216764e-01 + 2.94170163e+00 -1.83120916e+00 + -7.94798752e-01 7.39889052e-01 + 1.46555463e+00 -4.62275563e-01 + 2.57255955e+00 -1.04671434e+00 + 8.45042540e-01 -1.96952892e-01 + -3.23526646e+00 1.60049846e+00 + 3.21948565e+00 -8.88376674e-01 + 1.43005104e+00 -9.21561086e-01 + 8.82360506e-01 2.98403872e-01 + -8.91168097e-01 1.01319072e+00 + -5.13215241e-01 -2.47182649e-01 + -1.35759444e+00 7.07450608e-02 + -4.04550983e+00 2.23534867e+00 + 1.39348883e+00 3.81637747e-01 + -2.85676418e+00 1.53240862e+00 + -1.37183120e+00 6.37977425e-02 + -3.88195859e+00 1.73887145e+00 + 1.19509776e+00 -6.25013512e-01 + -2.80062734e+00 1.79840585e+00 + 1.96558429e+00 -4.70997234e-01 + 1.93111352e+00 -9.70318441e-01 + 3.57991190e+00 -1.65065116e+00 + 2.12831714e+00 -1.11531708e+00 + -3.95661018e-01 -8.54339904e-02 + -2.41630441e+00 1.65166304e+00 + 7.55412624e-01 -1.53453579e-01 + -1.77043450e+00 1.39928715e+00 + -9.32631260e-01 8.73649199e-01 + 1.53342205e+00 -8.39569765e-01 + -6.29846924e-02 1.25023084e-01 + 3.31509049e+00 -1.10733235e+00 + -2.18957109e+00 3.07376993e-01 + -2.35740747e+00 6.47437564e-01 + -2.22142438e+00 8.47318938e-01 + -6.51401147e-01 3.48398562e-01 + 2.75763095e+00 -1.21390708e+00 + 1.12550484e+00 -5.61412847e-01 + -5.65053161e-01 6.74365205e-02 + 1.68952456e+00 -6.57566096e-01 + 8.95598401e-01 3.96738993e-01 + -1.86537066e+00 9.44129208e-01 + -2.59933294e+00 2.57423247e-01 + -6.59598267e-01 1.91828851e-02 + -2.64506676e+00 8.41783205e-01 + -1.25911802e+00 5.52425066e-01 + -1.39754507e+00 3.73689222e-01 + 5.49550729e-02 1.35071215e+00 + 3.31874811e+00 -1.05682424e+00 + 3.63159604e+00 -1.42864695e+00 + -4.45944617e+00 1.42889446e+00 + 5.87314342e-01 -4.88892988e-01 + -7.26130820e-01 1.51936106e-01 + -1.79246441e+00 6.05888105e-01 + -5.50948207e-01 6.21443081e-01 + -3.17246063e-01 1.77213880e-01 + -2.00098937e+00 1.23799074e+00 + 4.33790961e+00 -1.08490465e+00 + -2.03114114e+00 1.31613237e+00 + -6.29216542e+00 1.92406317e+00 + -1.60265624e+00 8.87947500e-01 + 8.64465062e-01 -8.37416270e-01 + -2.14273937e+00 8.05485900e-01 + -2.36844256e+00 6.17915124e-01 + -1.40429636e+00 6.78296866e-01 + 9.99019988e-01 -5.84297572e-01 + 7.38824546e-01 1.68838678e-01 + 1.45681238e+00 3.04641461e-01 + 2.15914949e+00 -3.43089227e-01 + -1.23895930e+00 1.05339864e-01 + -1.23162264e+00 6.46629863e-01 + 2.28183862e+00 -9.24157063e-01 + -4.29615882e-01 5.69130863e-01 + -1.37449121e+00 -9.12032183e-01 + -7.33890904e-01 -3.91865471e-02 + 8.41400661e-01 -4.76002200e-01 + -1.73349274e-01 -6.84143467e-02 + 3.16042891e+00 -1.32651856e+00 + -3.78244609e+00 2.38619718e+00 + -3.69634380e+00 2.22368561e+00 + 1.83766344e+00 -1.65675953e+00 + -1.63206002e+00 1.19484469e+00 + 3.68480064e-01 -5.70764494e-01 + 3.61982479e-01 1.04274409e-01 + 2.48863048e+00 -1.13285542e+00 + -2.81896488e+00 9.47958768e-01 + 5.74952901e-01 -2.75959392e-01 + 3.72783275e-01 -3.48937848e-01 + 1.95935716e+00 -1.06750415e+00 + 5.19357531e+00 -2.32070803e+00 + 4.09246149e+00 -1.89976700e+00 + -3.36666087e-01 8.17645057e-02 + 1.85453493e-01 3.76913151e-01 + -3.06458262e+00 1.34106402e+00 + -3.13796566e+00 7.00485099e-01 + 1.42964058e+00 -1.35536932e-01 + -1.23440423e-01 4.60094177e-02 + -2.86753037e+00 -5.21724160e-02 + 2.67113726e+00 -1.83746924e+00 + -1.35335062e+00 1.28238073e+00 + -2.43569899e+00 1.25998539e+00 + 1.26036740e-01 -2.35416844e-01 + -1.35725745e+00 7.37788491e-01 + -3.80897538e-01 3.30757889e-01 + 6.58694434e-01 -1.07566603e+00 + 2.11273640e+00 -9.02260632e-01 + 4.00755057e-01 -2.49229150e-02 + -1.80095812e+00 9.73099742e-01 + -2.68408372e+00 1.63737364e+00 + -2.66079826e+00 7.47289412e-01 + -9.92321439e-02 -1.49331396e-01 + 4.45678251e+00 -1.80352394e+00 + 1.35962915e+00 -1.31554389e+00 + -7.76601417e-01 -9.66173523e-02 + 1.68096348e+00 -6.27235133e-01 + 1.53081227e-01 -3.54216830e-01 + -1.54913095e+00 3.43689269e-01 + 5.29187357e-02 -6.73916964e-01 + -2.06606084e+00 8.34784242e-01 + 1.73701179e+00 -6.06467340e-01 + 1.55856757e+00 -2.58642780e-01 + 1.04349101e+00 -4.43027348e-01 + -1.02397719e+00 1.01308824e+00 + -2.13860204e-01 -4.73347361e-01 + -2.59004955e+00 1.43367853e+00 + 7.98457679e-01 2.18621627e-02 + -1.32974762e+00 4.61802208e-01 + 3.21419359e-01 2.30723316e-02 + 2.87201888e-02 6.24566672e-02 + -1.22261418e+00 6.02340363e-01 + 1.28750335e+00 -3.34839548e-02 + -9.67952623e-01 4.34470505e-01 + 2.02850324e+00 -9.05160255e-01 + -4.13946010e+00 2.33779091e+00 + -4.47508806e-01 3.06440495e-01 + -3.91543394e+00 1.68251022e+00 + -6.45193001e-01 5.29781162e-01 + -2.15518916e-02 5.07278355e-01 + -2.83356868e+00 1.00670227e+00 + 1.82989749e+00 -1.37329222e+00 + -1.09330213e+00 1.08560688e+00 + 1.90533722e+00 -1.28905879e+00 + 2.33986084e+00 2.30642626e-02 + 8.01940220e-01 -1.63986962e+00 + -4.23415165e+00 2.07530423e+00 + 9.33382522e-01 -7.62917211e-01 + -1.84033954e+00 1.07469401e+00 + -2.81938669e+00 1.07342024e+00 + -7.05169988e-01 2.13124943e-01 + 5.09598137e-01 1.32725493e-01 + -2.34558226e+00 8.62383168e-01 + -1.70322072e+00 2.70893796e-01 + 1.23652660e+00 -7.53216034e-02 + 2.84660646e+00 -3.48178304e-02 + 2.50250128e+00 -1.27770855e+00 + -1.00279469e+00 8.77194218e-01 + -4.34674121e-02 -2.12091350e-01 + -5.84151289e-01 1.50382340e-01 + -1.79024013e+00 4.24972808e-01 + -1.23434666e+00 -8.85546570e-02 + 1.36575412e+00 -6.42639880e-01 + -1.98429947e+00 2.27650336e-01 + 2.36253589e+00 -1.51340773e+00 + 8.79157643e-01 6.84142159e-01 + -2.18577755e+00 2.76526200e-01 + -3.55473434e-01 8.29976561e-01 + 1.16442595e+00 -5.97699411e-01 + -7.35528097e-01 2.40318183e-01 + -1.73702631e-01 7.33788663e-02 + -1.40451745e+00 3.24899628e-01 + -2.05434385e+00 5.68123738e-01 + 8.47876642e-01 -5.74224294e-01 + -6.91955602e-01 1.26009087e+00 + 2.56574498e+00 -1.15602581e+00 + 3.93306545e+00 -1.38398209e+00 + -2.73230251e+00 4.89062581e-01 + -1.04315474e+00 6.06335547e-01 + 1.23231431e+00 -4.46675065e-01 + -3.93035285e+00 1.43287651e+00 + -1.02132111e+00 9.58919791e-01 + -1.49425352e+00 1.06456165e+00 + -6.26485337e-01 1.03791402e+00 + -6.61772998e-01 2.63275425e-01 + -1.80940386e+00 5.70767403e-01 + 9.83720450e-01 -1.39449756e-01 + -2.24619662e+00 9.01044870e-01 + 8.94343014e-01 5.31038678e-02 + 1.95518199e-01 -2.81343295e-01 + -2.30533019e-01 -1.74478106e-01 + -2.01550361e+00 5.55958010e-01 + -4.36281469e+00 1.94374226e+00 + -5.18530457e+00 2.89278357e+00 + 2.67289101e+00 -2.98511449e-01 + -1.53566179e+00 -1.00588944e-01 + -6.09943217e-02 -1.56986047e-01 + -5.22146452e+00 1.66209208e+00 + -3.69777478e+00 2.26154873e+00 + 2.24607181e-01 -4.86934960e-01 + 2.49909450e+00 -1.03033370e+00 + -1.07841120e+00 8.22388054e-01 + -3.20697089e+00 1.09536143e+00 + 3.43524232e+00 -1.47289362e+00 + -5.65784134e-01 4.60365175e-01 + -1.76714734e+00 1.57752346e-01 + -7.77620365e-01 5.60153443e-01 + 6.34399352e-01 -5.22339836e-01 + 2.91011875e+00 -9.72623380e-01 + -1.19286824e+00 6.32370253e-01 + -2.18327609e-01 8.23953181e-01 + 3.42430842e-01 1.37098055e-01 + 1.28658034e+00 -9.11357320e-01 + 2.06914465e+00 -6.67556382e-01 + -6.69451020e-01 -6.38605102e-01 + -2.09312398e+00 1.16743634e+00 + -3.63778357e+00 1.91919157e+00 + 8.74685911e-01 -1.09931208e+00 + -3.91496791e+00 1.00808357e+00 + 1.29621330e+00 -8.32239802e-01 + 9.00222045e-01 -1.31159793e+00 + -1.12242062e+00 1.98517079e-01 + -3.71932852e-01 1.31667093e-01 + -2.23829610e+00 1.26328346e+00 + -2.08365062e+00 9.93385336e-01 + -1.91082720e+00 7.45866855e-01 + 4.38024917e+00 -2.05901118e+00 + -2.28872886e+00 6.85279335e-01 + 1.01274497e-01 -3.26227153e-01 + -5.04447572e-01 -3.18619513e-01 + 1.28537006e+00 -1.04573551e+00 + -7.83175212e-01 1.54791645e-01 + -3.89239175e+00 1.60017929e+00 + -8.87877111e-01 -1.04968005e-01 + 9.32215179e-01 -5.58691113e-01 + -6.44977127e-01 -2.23018375e-01 + 1.10141900e+00 -1.00666432e+00 + 2.92755687e-01 -1.45480350e-01 + 7.73580681e-01 -2.21150567e-01 + -1.40873709e+00 7.61548044e-01 + -8.89031805e-01 -3.48542923e-01 + 4.16844267e-01 -2.39914494e-01 + -4.64265832e-01 7.29581138e-01 + 1.99835179e+00 -7.70542813e-01 + 4.20523191e-02 -2.18783563e-01 + -6.32611758e-01 -3.09926115e-01 + 6.82912198e-02 -8.48327050e-01 + 1.92425229e+00 -1.37876951e+00 + 3.49461782e+00 -1.88354255e+00 + -3.25209026e+00 1.49809395e+00 + 6.59273182e-01 -2.37435654e-01 + -1.15517300e+00 8.46134387e-01 + 1.26756151e+00 -4.58988026e-01 + -3.99178418e+00 2.04153008e+00 + 7.05687841e-01 -6.83433306e-01 + -1.61997342e+00 8.16577004e-01 + -3.89750399e-01 4.29753250e-01 + -2.53026432e-01 4.92861432e-01 + -3.16788324e+00 4.44285524e-01 + -7.86248901e-01 1.12753716e+00 + -3.02351433e+00 1.28419015e+00 + -1.30131355e+00 1.71226678e+00 + -4.08843475e+00 1.62063214e+00 + -3.09209403e+00 1.19958520e+00 + 1.49102271e+00 -1.11834864e+00 + -3.18059348e+00 5.74587042e-01 + 2.06054867e+00 3.25797860e-03 + -3.50999200e+00 2.02412428e+00 + -8.26610023e-01 3.46528211e-01 + 2.00546034e+00 -4.07333110e-01 + -9.69941653e-01 4.80953753e-01 + 4.47925660e+00 -2.33127314e+00 + 2.03845790e+00 -9.90439915e-01 + -1.11349191e+00 4.31183918e-01 + -4.03628396e+00 1.68509679e+00 + -1.48177601e+00 7.74322088e-01 + 3.07369385e+00 -9.57465886e-01 + 2.39011286e+00 -6.44506921e-01 + 2.91561991e+00 -8.78627328e-01 + 1.10212733e+00 -4.21637388e-01 + 5.31985231e-01 -6.17445696e-01 + -6.82340929e-01 -2.93529716e-01 + 1.94290679e+00 -4.64268634e-01 + 1.92262116e+00 -7.93142835e-01 + 4.73762800e+00 -1.63654174e+00 + -3.17848641e+00 8.05791391e-01 + 4.08739432e+00 -1.80816807e+00 + -7.60648826e-01 1.24216138e-01 + -2.24716400e+00 7.90020937e-01 + 1.64284052e+00 -7.18784070e-01 + 1.04410012e-01 -7.11195880e-02 + 2.18268225e+00 -7.01767831e-01 + 2.06218013e+00 -8.70251746e-01 + -1.35266581e+00 7.08456358e-01 + -1.38157779e+00 5.14401086e-01 + -3.28326008e+00 1.20988399e+00 + 8.85358917e-01 -8.12213495e-01 + -2.34067500e+00 3.67657353e-01 + 3.96878127e+00 -1.66841450e+00 + 1.36518053e+00 -8.33436812e-01 + 5.25771988e-01 -5.06121987e-01 + -2.25948361e+00 1.30663765e+00 + -2.57662070e+00 6.32114628e-01 + -3.43134685e+00 2.38106008e+00 + 2.31571924e+00 -1.56566818e+00 + -2.95397202e+00 1.05661888e+00 + -1.35331242e+00 6.76383411e-01 + 1.40977132e+00 -1.17775938e+00 + 1.52561996e+00 -9.83147176e-01 + 2.26550832e+00 -2.10464123e-02 + 6.23371684e-01 -5.30768122e-01 + -4.42356624e-01 9.72226986e-01 + 2.31517901e+00 -1.08468105e+00 + 1.97236640e+00 -1.42016619e+00 + 3.18618687e+00 -1.45056343e+00 + -2.75880360e+00 5.40254980e-01 + -1.92916581e+00 1.45029864e-01 + 1.90022524e+00 -6.03805754e-01 + -1.05446211e+00 5.74361752e-01 + 1.45990390e+00 -9.28233993e-01 + 5.14960557e+00 -2.07564096e+00 + -7.53104842e-01 1.55876958e-01 + 8.09490983e-02 -8.58886384e-02 + -1.56894969e+00 4.53497227e-01 + 1.36944658e-01 5.60670875e-01 + -5.32635329e-01 4.40309945e-01 + 1.32507853e+00 -5.83670099e-01 + 1.20676031e+00 -8.02296831e-01 + -3.65023422e+00 1.17211368e+00 + 1.53393850e+00 -6.17771312e-01 + -3.99977129e+00 1.71415137e+00 + 5.70705058e-01 -4.60771539e-01 + -2.20608002e+00 1.07866596e+00 + -1.09040244e+00 6.77441076e-01 + -5.09886482e-01 -1.97282128e-01 + -1.58062785e+00 6.18333697e-01 + -1.53295020e+00 4.02168701e-01 + -5.18580598e-01 2.25767177e-01 + 1.59514316e+00 -2.54983617e-01 + -5.91938655e+00 2.68223782e+00 + 2.84200509e+00 -1.04685313e+00 + 1.31298664e+00 -1.16672614e+00 + -2.36660033e+00 1.81359460e+00 + 6.94163290e-02 3.76658816e-01 + 2.33973934e+00 -8.33173023e-01 + -8.24640389e-01 7.83717285e-01 + -1.02888281e+00 1.04680766e+00 + 1.34750745e+00 -5.89568160e-01 + -2.48761231e+00 7.44199284e-01 + -1.04501559e+00 4.72326911e-01 + -3.14610089e+00 1.89843692e+00 + 2.13003416e-01 5.76633620e-01 + -1.69239608e+00 5.66070021e-01 + 1.80491280e+00 -9.31701080e-01 + -6.94362572e-02 6.96026587e-01 + 1.36502578e+00 -6.85599000e-02 + -7.76764337e-01 3.64328661e-01 + -2.67322167e+00 6.80150021e-01 + 1.84338485e+00 -1.18487494e+00 + 2.88009231e+00 -1.25700411e+00 + 1.17114433e+00 -7.69727080e-01 + 2.11576167e+00 2.81502116e-01 + -1.51470088e+00 2.61553540e-01 + 1.18923669e-01 -1.17890202e-01 + 4.48359786e+00 -1.81427466e+00 + -1.27055948e+00 9.92388998e-01 + -8.00276606e-01 9.11326621e-02 + 7.51764024e-01 -1.03676498e-01 + 1.35769348e-01 -2.11470084e-01 + 2.50731332e+00 -1.12418270e+00 + -2.49752781e-01 7.81224033e-02 + -6.23037902e-01 3.16599691e-01 + -3.93772902e+00 1.37195391e+00 + 1.74256361e+00 -1.12363582e+00 + -1.49737281e+00 5.98828310e-01 + 7.75592115e-01 -4.64733802e-01 + -2.26027693e+00 1.36991118e+00 + -1.62849836e+00 7.36899107e-01 + 2.36850751e+00 -9.32126872e-01 + 5.86169745e+00 -2.49342512e+00 + -5.37092226e-01 1.23821274e+00 + 2.80535867e+00 -1.93363302e+00 + -1.77638106e+00 9.10050276e-01 + 3.02692018e+00 -1.60774676e+00 + 1.97833084e+00 -1.50636531e+00 + 9.09168906e-01 -8.83799359e-01 + 2.39769655e+00 -7.56977869e-01 + 1.47283981e+00 -1.06749890e+00 + 2.92060943e-01 -6.07040605e-01 + -2.09278201e+00 7.71858590e-01 + 7.10015905e-01 -5.42768432e-01 + -2.16826169e-01 1.56897896e-01 + 4.56288247e+00 -2.08912680e+00 + -6.63374020e-01 6.67325183e-01 + 1.80564442e+00 -9.76366134e-01 + 3.28720168e+00 -4.66575145e-01 + -1.60463695e-01 -2.58428153e-01 + 1.78590750e+00 -3.96427146e-01 + 2.75950306e+00 -1.82102856e+00 + -1.18234310e+00 6.28073320e-01 + 4.11415835e+00 -2.33551216e+00 + 1.38721004e+00 -2.77450622e-01 + -2.94903545e+00 1.74813352e+00 + 8.67290400e-01 -6.51667894e-01 + 2.70022274e+00 -8.11832480e-01 + -2.06766146e+00 8.24047249e-01 + 3.90717142e+00 -1.20155758e+00 + -2.95102809e+00 1.36667968e+00 + 6.08815147e+00 -2.60737974e+00 + 2.78576476e+00 -7.86628755e-01 + -3.26258407e+00 1.09302450e+00 + 1.59849422e+00 -1.09705202e+00 + -2.50600710e-01 1.63243175e-01 + -4.90477087e-01 -4.57729572e-01 + -1.24837181e+00 3.22157840e-01 + -2.46341049e+00 1.06517849e+00 + 9.62880751e-01 4.56962496e-01 + 3.99964487e-01 2.07472802e-01 + 6.36657705e-01 -3.46400942e-02 + 4.91231407e-02 -1.40289235e-02 + -4.66683524e-02 -3.72326100e-01 + -5.22049702e-01 -1.70440260e-01 + 5.27062938e-01 -2.32628395e-01 + -2.69440318e+00 1.18914874e+00 + 3.65087539e+00 -1.53427267e+00 + -1.16546364e-01 4.93245392e-02 + 7.55931384e-01 -3.02980139e-01 + 2.06338745e+00 -6.24841225e-01 + 1.31177908e-01 7.29338183e-01 + 1.48021784e+00 -6.39509896e-01 + -5.98656707e-01 2.84525503e-01 + -2.18611080e+00 1.79549812e+00 + -2.91673624e+00 2.15772237e-01 + -8.95591350e-01 7.68250538e-01 + 1.36139762e+00 -1.93845144e-01 + 5.45730414e+00 -2.28114404e+00 + 3.22747247e-01 9.33582332e-01 + -1.46384504e+00 1.12801186e-01 + 4.26728166e-01 -2.33481242e-01 + -1.41327270e+00 8.16103740e-01 + -2.53998067e-01 1.44906646e-01 + -1.32436467e+00 1.87556361e-01 + -3.77313086e+00 1.32896038e+00 + 3.77651731e+00 -1.76548043e+00 + -2.45297093e+00 1.32571926e+00 + -6.55900588e-01 3.56921462e-01 + 9.25558722e-01 -4.51988954e-01 + 1.20732231e+00 -3.02821614e-01 + 3.72660154e-01 -1.89365208e-01 + -1.77090939e+00 9.18087975e-01 + 3.01127567e-01 2.67965829e-01 + -1.76708900e+00 4.62069259e-01 + -2.71812099e+00 1.57233508e+00 + -5.35297633e-01 4.99231535e-01 + 1.50507631e+00 -9.85763646e-01 + 3.00424787e+00 -1.29837562e+00 + -4.99311105e-01 3.91086482e-01 + 1.30125207e+00 -1.26247924e-01 + 4.01699483e-01 -4.46909391e-01 + -1.33635257e+00 5.12068703e-01 + 1.39229757e+00 -9.10974858e-01 + -1.74229508e+00 1.49475978e+00 + -1.21489414e+00 4.04193753e-01 + -3.36537605e-01 -6.74335427e-01 + -2.79186828e-01 8.48314720e-01 + -2.03080140e+00 1.66599815e+00 + -3.53064281e-01 -7.68582906e-04 + -5.30305657e+00 2.91091546e+00 + -1.20049972e+00 8.26578358e-01 + 2.95906989e-01 2.40215920e-01 + -1.42955534e+00 4.63480310e-01 + -1.87856619e+00 8.21459385e-01 + -2.71124720e+00 1.80246843e+00 + -3.06933780e+00 1.22235760e+00 + 5.21935582e-01 -1.27298218e+00 + -1.34175797e+00 7.69018937e-01 + -1.81962785e+00 1.15528991e+00 + -3.99227550e-01 2.93821598e-01 + 1.22533179e+00 -4.73846323e-01 + -2.08068359e-01 -1.75039817e-01 + -2.03068526e+00 1.50370503e+00 + -3.27606113e+00 1.74906330e+00 + -4.37802587e-01 -2.26956048e-01 + -7.69774213e-02 -3.54922468e-01 + 6.47160749e-02 -2.07334721e-01 + -1.37791524e+00 4.43766709e-01 + 3.29846803e+00 -1.04060799e+00 + -3.63704046e+00 1.05800226e+00 + -1.26716116e+00 1.13077353e+00 + 1.98549075e+00 -1.31864807e+00 + 1.85159500e+00 -5.78629560e-01 + -1.55295206e+00 1.23655857e+00 + 6.76026255e-01 9.18824125e-02 + 1.23418960e+00 -4.68162027e-01 + 2.43186642e+00 -9.22422440e-01 + -3.18729701e+00 1.77582673e+00 + -4.02945613e+00 1.14303496e+00 + -1.92694576e-01 1.03301431e-01 + 1.89554730e+00 -4.60128096e-01 + -2.55626581e+00 1.16057084e+00 + 6.89144365e-01 -9.94982900e-01 + -4.44680606e+00 2.19751983e+00 + -3.15196193e+00 1.18762993e+00 + -1.17434977e+00 1.04534656e+00 + 8.58386984e-02 -1.03947487e+00 + 3.33354973e-01 5.54813610e-01 + -9.37631808e-01 3.33450150e-01 + -2.50232471e+00 5.39720635e-01 + 1.03611949e+00 -7.16304095e-02 + -2.05556816e-02 -3.28992265e-01 + -2.24176201e+00 1.13077506e+00 + 4.53583688e+00 -1.10710212e+00 + 4.77389762e-01 -8.99445512e-01 + -2.69075551e+00 6.83176866e-01 + -2.21779724e+00 1.16916849e+00 + -1.09669056e+00 2.10044765e-01 + -8.45367920e-01 -8.45951423e-02 + 4.37558941e-01 -6.95904256e-01 + 1.84884195e+00 -1.71205136e-01 + -8.36371957e-01 5.62862478e-01 + 1.27786531e+00 -1.33362147e+00 + 2.90684492e+00 -7.49892184e-01 + -3.38652716e+00 1.51180670e+00 + -1.30945978e+00 7.09261928e-01 + -7.50471924e-01 -5.24637889e-01 + 1.18580718e+00 -9.97943971e-04 + -7.55395645e+00 3.19273590e+00 + 1.72822535e+00 -1.20996962e+00 + 5.67374320e-01 6.19573416e-01 + -2.99163781e+00 1.79721534e+00 + 1.49862187e+00 -6.05631846e-02 + 1.79503506e+00 -4.90419706e-01 + 3.85626054e+00 -1.95396324e+00 + -9.39188410e-01 7.96498057e-01 + 2.91986664e+00 -1.29392724e+00 + -1.54265750e+00 6.40727933e-01 + 1.14919794e+00 1.20834257e-01 + 2.00936817e+00 -1.53728359e+00 + 3.72468420e+00 -1.38704612e+00 + -1.27794802e+00 3.48543179e-01 + 3.63294077e-01 5.70623314e-01 + 1.49381016e+00 -6.04500534e-01 + 2.98912256e+00 -1.72295726e+00 + -1.80833817e+00 2.94907625e-01 + -3.19669622e+00 1.31888700e+00 + 1.45889401e+00 -8.88448639e-01 + -2.80045388e+00 1.01207060e+00 + -4.78379567e+00 1.48646520e+00 + 2.25510003e+00 -7.13372461e-01 + -9.74441433e-02 -2.17766373e-01 + 2.64468496e-01 -3.60842698e-01 + -5.98821713e+00 3.20197892e+00 + 2.67030213e-01 -5.36386416e-01 + 2.24546960e+00 -8.13464649e-01 + -4.89171414e-01 3.86255031e-01 + -7.45713706e-01 6.29800380e-01 + -3.30460503e-01 3.85127284e-01 + -4.19588147e+00 1.52793198e+00 + 5.42078582e-01 -2.61642741e-02 + 4.24938513e-01 -5.72936751e-01 + 2.82717288e+00 -6.75355024e-01 + -1.44741788e+00 5.03578028e-01 + -1.65547573e+00 7.76444277e-01 + 2.20361170e+00 -1.40835680e+00 + -3.69540235e+00 2.32953767e+00 + -1.41909357e-01 2.28989778e-01 + 1.92838879e+00 -8.72525737e-01 + 1.40708100e+00 -6.81849638e-02 + 1.24988112e+00 -1.39470590e-01 + -2.39435855e+00 7.26587655e-01 + 7.03985028e-01 4.85403277e-02 + 4.05214529e+00 -9.16928318e-01 + 3.74198837e-01 -5.04192358e-01 + -8.43374127e-01 2.36064018e-01 + -3.32253349e-01 7.47840055e-01 + -6.03725210e+00 1.95173337e+00 + 4.60829865e+00 -1.51191309e+00 + -1.46247098e+00 1.11140916e+00 + -9.60111157e-01 -1.23189114e-01 + -7.49613187e-01 4.53614129e-01 + -5.77838219e-01 2.07366469e-02 + 8.07652950e-01 -5.16272662e-01 + -6.02556049e-01 5.05318649e-01 + -1.28712445e-01 2.57836512e-01 + -5.27662820e+00 2.11790737e+00 + 5.40819308e+00 -2.15366022e+00 + 9.37742513e-02 -1.60221751e-01 + 4.55902865e+00 -1.24646307e+00 + -9.06582589e-01 1.92928110e-01 + 2.99928996e+00 -8.04301218e-01 + -3.24317381e+00 1.80076061e+00 + 3.20421743e-01 8.76524679e-01 + -5.29606705e-01 -3.16717696e-01 + -1.77264560e+00 7.52686776e-01 + -1.51706824e+00 8.43755103e-01 + 1.52759111e+00 -7.86814243e-01 + 4.74845617e-01 4.21319700e-01 + 6.97829149e-01 -8.15664881e-01 + 3.09564973e+00 -1.06202469e+00 + 2.95320379e+00 -1.98963943e+00 + -4.23033224e+00 1.41013338e+00 + 1.48576206e+00 8.02908511e-02 + 4.52041627e+00 -2.04620399e+00 + 6.58403922e-01 -7.60781799e-01 + 2.10667543e-01 1.15241731e-01 + 1.77702583e+00 -8.10271859e-01 + 2.41277385e+00 -1.46972042e+00 + 1.50685525e+00 -1.99272545e-01 + 7.61665522e-01 -4.11276152e-01 + 1.18352312e+00 -9.59908608e-01 + -3.32031305e-01 8.07500132e-02 + 1.16813118e+00 -1.73095194e-01 + 1.18363346e+00 -5.41565052e-01 + 5.17702179e-01 -7.62442035e-01 + 4.57401006e-01 -1.45951115e-02 + 1.49377115e-01 2.99571605e-01 + 1.40399453e+00 -1.30160353e+00 + 5.26231567e-01 3.52783752e-01 + -1.91136514e+00 4.24228635e-01 + 1.74156701e+00 -9.92076776e-01 + -4.89323391e+00 2.32483507e+00 + 2.54011209e+00 -8.80366295e-01 + -5.56925706e-01 1.48842026e-01 + -2.35904668e+00 9.60474853e-01 + 1.42216971e+00 -4.67062761e-01 + -1.10809680e+00 7.68684300e-01 + 4.09674726e+00 -1.90795680e+00 + -2.23048923e+00 9.03812542e-01 + 6.57025763e-01 1.36514871e-01 + 2.10944145e+00 -9.78897838e-02 + 1.22552525e+00 -2.50303867e-01 + 2.84620103e-01 -5.30164020e-01 + -2.13562585e+00 1.03503056e+00 + 1.32414902e-01 -8.14190240e-03 + -5.82433561e-01 3.21020292e-01 + -5.06473247e-01 3.11530419e-01 + 1.57162465e+00 -1.20763919e+00 + -1.43155284e+00 -2.51203698e-02 + -1.47093713e+00 -1.39620999e-01 + -2.65765643e+00 1.06091403e+00 + 2.45992927e+00 -5.88815836e-01 + -1.28440162e+00 -1.99377398e-01 + 6.11257504e-01 -3.73577401e-01 + -3.46606103e-01 6.06081290e-01 + 3.76687505e+00 -8.80181424e-01 + -1.03725103e+00 1.45177517e+00 + 2.76659936e+00 -1.09361320e+00 + -3.61311296e+00 9.75032455e-01 + 3.22878655e+00 -9.69497365e-01 + 1.43560379e+00 -5.52524585e-01 + 2.94042153e+00 -1.79747037e+00 + 1.30739580e+00 2.47989248e-01 + -4.05056982e-01 1.22831715e+00 + -2.25827421e+00 2.30604626e-01 + 3.69262926e-01 4.32714650e-02 + -5.52064063e-01 6.07806340e-01 + 7.03325987e+00 -2.17956730e+00 + -2.37823835e-01 -8.28068639e-01 + -4.84279888e-01 5.67765194e-01 + -3.15863410e+00 1.02241617e+00 + -3.39561593e+00 1.36876374e+00 + -2.78482934e+00 6.81641104e-01 + -4.37604334e+00 2.23826340e+00 + -2.54049692e+00 8.22676745e-01 + 3.73264822e+00 -9.93498732e-01 + -3.49536064e+00 1.84771519e+00 + 9.81801604e-01 -5.21278776e-01 + 1.52996831e+00 -1.27386206e+00 + -9.23490293e-01 5.29099482e-01 + -2.76999461e+00 9.24831872e-01 + -3.30029834e-01 -2.49645555e-01 + -1.71156166e+00 5.44940854e-01 + -2.37009487e+00 5.83826982e-01 + -3.03216865e+00 1.04922722e+00 + -2.19539936e+00 1.37558730e+00 + 1.15350207e+00 -6.15318535e-01 + 4.62011792e+00 -2.46714517e+00 + 1.52627952e-02 -1.00618283e-01 + -1.10399342e+00 4.87413533e-01 + 3.55448194e+00 -9.10394190e-01 + -5.21890321e+00 2.44710745e+00 + 1.54289749e+00 -6.54269311e-01 + 2.67935674e+00 -9.92758863e-01 + 1.05801310e+00 2.60054285e-02 + 1.52509097e+00 -4.08768600e-01 + 3.27576917e+00 -1.28769406e+00 + 1.71008412e-01 -2.68739994e-01 + -9.83351344e-04 7.02495897e-02 + -7.60795056e-03 1.61968285e-01 + -1.80620472e+00 4.24934471e-01 + 2.32023297e-02 -2.57284559e-01 + 3.98219478e-01 -4.65361935e-01 + 6.63476988e-01 -3.29823196e-02 + 4.00154707e+00 -1.01792211e+00 + -1.50286870e+00 9.46875359e-01 + -2.22717585e+00 7.50636195e-01 + -3.47381508e-01 -6.51596975e-01 + 2.08076453e+00 -8.22800165e-01 + 2.05099963e+00 -4.00868250e-01 + 3.52576988e-02 -2.54418565e-01 + 1.57342042e+00 -7.62166492e-02 + -1.47019722e+00 3.40861172e-01 + -1.21156090e+00 3.21891246e-01 + 3.79729047e+00 -1.54350764e+00 + 1.26459678e-02 6.99203693e-01 + 1.53974177e-01 4.68643204e-01 + -1.73923561e-01 -1.26229768e-01 + 4.54644993e+00 -2.13951783e+00 + 1.46022547e-01 -4.57084165e-01 + 6.50048037e+00 -2.78872609e+00 + -1.51934912e+00 1.03216768e+00 + -3.06483575e+00 1.81101446e+00 + -2.38212125e+00 9.19559042e-01 + -1.81319611e+00 8.10545112e-01 + 1.70951294e+00 -6.10712680e-01 + 1.67974156e+00 -1.51241453e+00 + -5.94795113e+00 2.56893813e+00 + 3.62633110e-01 -7.46965304e-01 + -2.44042594e+00 8.52761797e-01 + 3.32412550e+00 -1.28439899e+00 + 4.74860766e+00 -1.72821964e+00 + 1.29072541e+00 -8.24872902e-01 + -1.69450702e+00 4.09600876e-01 + 1.29705411e+00 1.22300809e-01 + -2.63597613e+00 8.55612913e-01 + 9.28467301e-01 -2.63550114e-02 + 2.44670264e+00 -4.10123002e-01 + 1.06408206e+00 -5.03361942e-01 + 5.12384049e-02 -1.27116595e-02 + -1.06731272e+00 -1.76205029e-01 + -9.45454582e-01 3.74404917e-01 + 2.54343689e+00 -7.13810545e-01 + -2.54460335e+00 1.31590265e+00 + 1.89864233e+00 -3.98436339e-01 + -1.93990133e+00 6.01474630e-01 + -1.35938824e+00 4.00751788e-01 + 2.38567018e+00 -6.13904880e-01 + 2.18748050e-01 2.62631712e-01 + -2.01388788e+00 1.41474031e+00 + 2.74014581e+00 -1.27448105e+00 + -2.13828583e+00 1.13616144e+00 + 5.98730932e+00 -2.53430080e+00 + -1.72872795e+00 1.53702057e+00 + -2.53263962e+00 1.27342410e+00 + 1.34326968e+00 -1.99395088e-01 + 3.83352666e-01 -1.25683065e-01 + -2.35630657e+00 5.54116983e-01 + -1.94900838e+00 5.76270178e-01 + -1.36699108e+00 -3.40904824e-01 + -2.34727346e+00 -1.93054940e-02 + -3.82779777e+00 1.83025664e+00 + -4.31602080e+00 9.21605705e-01 + 5.54098133e-01 2.33991419e-01 + -4.53591188e+00 1.99833353e+00 + -3.92715909e+00 1.83231482e+00 + 3.91344440e-01 -1.11355111e-01 + 3.48576363e+00 -1.41379449e+00 + -1.42858690e+00 3.84532286e-01 + 1.79519859e+00 -9.23486448e-01 + 8.49691242e-01 -1.76551084e-01 + 1.53618138e+00 8.23835015e-02 + 5.91476520e-02 3.88296940e-02 + 1.44837346e+00 -7.24097604e-01 + -6.79008418e-01 4.04078097e-01 + 2.87555510e+00 -9.51825076e-01 + -1.12379101e+00 2.93457714e-01 + 1.45263980e+00 -6.01960544e-01 + -2.55741621e-01 9.26233518e-01 + 3.54570714e+00 -1.41521877e+00 + -1.61542388e+00 6.57844512e-01 + -3.22844269e-01 3.02823546e-01 + 1.03523913e+00 -6.92730711e-01 + 1.11084909e+00 -3.50823642e-01 + 3.41268693e+00 -1.90865862e+00 + 7.67062858e-01 -9.48792160e-01 + -5.49798016e+00 1.71139960e+00 + 1.14865798e+00 -6.12669150e-01 + -2.18256680e+00 7.78634462e-01 + 4.78857389e+00 -2.55555085e+00 + -1.85555569e+00 8.04311615e-01 + -4.22278799e+00 2.01162524e+00 + -1.56556149e+00 1.54353907e+00 + -3.11527864e+00 1.65973526e+00 + 2.66342611e+00 -1.20449402e+00 + 1.57635314e+00 -1.48716308e-01 + -6.35606865e-01 2.59701180e-01 + 1.02431976e+00 -6.76929904e-01 + 1.12973772e+00 1.49473892e-02 + -9.12758116e-01 2.21533933e-01 + -2.98014470e+00 1.71651189e+00 + 2.74016965e+00 -9.47893923e-01 + -3.47830591e+00 1.34941430e+00 + 1.74757562e+00 -3.72503752e-01 + 5.55820383e-01 -6.47992466e-01 + -1.19871928e+00 9.82429151e-01 + -2.53040133e+00 2.10671307e+00 + -1.94085605e+00 1.38938137e+00 diff --git a/data/mllib/sample_fpgrowth.txt b/data/mllib/sample_fpgrowth.txt new file mode 100644 index 000000000000..c451583e5131 --- /dev/null +++ b/data/mllib/sample_fpgrowth.txt @@ -0,0 +1,6 @@ +r z h k p +z y x w v u t s +s x o n r +x z y m t s q e +z +x z y r q t p diff --git a/data/mllib/sample_isotonic_regression_data.txt b/data/mllib/sample_isotonic_regression_data.txt new file mode 100644 index 000000000000..d257b509d4d3 --- /dev/null +++ b/data/mllib/sample_isotonic_regression_data.txt @@ -0,0 +1,100 @@ +0.24579296,0.01 +0.28505864,0.02 +0.31208567,0.03 +0.35900051,0.04 +0.35747068,0.05 +0.16675166,0.06 +0.17491076,0.07 +0.04181540,0.08 +0.04793473,0.09 +0.03926568,0.10 +0.12952575,0.11 +0.00000000,0.12 +0.01376849,0.13 +0.13105558,0.14 +0.08873024,0.15 +0.12595614,0.16 +0.15247323,0.17 +0.25956145,0.18 +0.20040796,0.19 +0.19581846,0.20 +0.15757267,0.21 +0.13717491,0.22 +0.19020908,0.23 +0.19581846,0.24 +0.20091790,0.25 +0.16879143,0.26 +0.18510964,0.27 +0.20040796,0.28 +0.29576747,0.29 +0.43396226,0.30 +0.53391127,0.31 +0.52116267,0.32 +0.48546660,0.33 +0.49209587,0.34 +0.54156043,0.35 +0.59765426,0.36 +0.56144824,0.37 +0.58592555,0.38 +0.52983172,0.39 +0.50178480,0.40 +0.52626211,0.41 +0.58286588,0.42 +0.64660887,0.43 +0.68077511,0.44 +0.74298827,0.45 +0.64864865,0.46 +0.67261601,0.47 +0.65782764,0.48 +0.69811321,0.49 +0.63029067,0.50 +0.61601224,0.51 +0.63233044,0.52 +0.65323814,0.53 +0.65323814,0.54 +0.67363590,0.55 +0.67006629,0.56 +0.51555329,0.57 +0.50892402,0.58 +0.33299337,0.59 +0.36206017,0.60 +0.43090260,0.61 +0.45996940,0.62 +0.56348802,0.63 +0.54920959,0.64 +0.48393677,0.65 +0.48495665,0.66 +0.46965834,0.67 +0.45181030,0.68 +0.45843957,0.69 +0.47118817,0.70 +0.51555329,0.71 +0.58031617,0.72 +0.55481897,0.73 +0.56297807,0.74 +0.56603774,0.75 +0.57929628,0.76 +0.64762876,0.77 +0.66241713,0.78 +0.69301377,0.79 +0.65119837,0.80 +0.68332483,0.81 +0.66598674,0.82 +0.73890872,0.83 +0.73992861,0.84 +0.84242733,0.85 +0.91330954,0.86 +0.88016318,0.87 +0.90719021,0.88 +0.93115757,0.89 +0.93115757,0.90 +0.91942886,0.91 +0.92911780,0.92 +0.95665477,0.93 +0.95002550,0.94 +0.96940337,0.95 +1.00000000,0.96 +0.89801122,0.97 +0.90311066,0.98 +0.90362060,0.99 +0.83477817,1.0 \ No newline at end of file diff --git a/data/mllib/sample_lda_data.txt b/data/mllib/sample_lda_data.txt new file mode 100644 index 000000000000..2e76702ca9d6 --- /dev/null +++ b/data/mllib/sample_lda_data.txt @@ -0,0 +1,12 @@ +1 2 6 0 2 3 1 1 0 0 3 +1 3 0 1 3 0 0 2 0 0 1 +1 4 1 0 0 4 9 0 1 2 0 +2 1 0 3 0 0 5 0 2 3 9 +3 1 1 9 3 0 2 0 0 1 3 +4 2 0 3 4 5 1 1 1 4 0 +2 1 0 3 0 0 5 0 2 2 9 +1 1 1 9 2 1 2 0 0 1 3 +4 4 0 3 4 2 1 3 0 0 0 +2 8 2 0 3 0 2 0 2 7 2 +1 1 1 9 0 2 2 0 0 3 3 +4 1 0 0 4 5 1 3 0 1 0 diff --git a/dev/.gitignore b/dev/.gitignore new file mode 100644 index 000000000000..4a6027429e0d --- /dev/null +++ b/dev/.gitignore @@ -0,0 +1 @@ +pep8*.py diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index d888de929fdd..cc86ef45858c 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -36,8 +36,10 @@ object SparkSqlExample { val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ import sqlContext._ - val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) + + val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)).toDF() people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 7473c20d28e0..c4adb1f96b7d 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -16,5 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -find . -name 'pom.xml' | grep -v target \ - | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} + +# Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X) +BASEDIR=$(dirname $0)/.. +find $BASEDIR -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {} + +# Also update in parent POM +sed -i -e '0,/2.112.10 in parent POM +sed -i -e '0,/2.102.11/dev/null; then - curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif hash wget 2>/dev/null; then + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi - fi + fi - unzip -tq $JAR &> /dev/null - if [ $? -ne 0 ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" - exit -1 - fi - printf "Launching rat from ${JAR}\n" + unzip -tq "$JAR" &> /dev/null + if [ $? -ne 0 ]; then + # We failed to download + rm "$JAR" + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi } @@ -71,6 +69,11 @@ mkdir -p "$FWDIR"/lib $java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt +if [ $? -ne 0 ]; then + echo "RAT exited abnormally" + exit 1 +fi + ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098..b5a67dd783b9 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -22,8 +22,9 @@ # Expects to be run in a totally empty directory. # # Options: -# --package-only only packages an existing release candidate -# +# --skip-create-release Assume the desired release tag already exists +# --skip-publish Do not publish to Maven central +# --skip-package Do not package and upload binary artifacts # Would be nice to add: # - Send output to stderr and have useful logging in stdout @@ -33,6 +34,9 @@ ASF_PASSWORD=${ASF_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} GIT_BRANCH=${GIT_BRANCH:-branch-1.0} RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} +# Allows publishing under a different version identifier than +# was present in the actual release sources (e.g. rc-X) +PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} NEXT_VERSION=${NEXT_VERSION:-1.2.1} RC_NAME=${RC_NAME:-rc2} @@ -51,7 +55,7 @@ set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME -if [[ ! "$@" =~ --package-only ]]; then +if [[ ! "$@" =~ --skip-create-release ]]; then echo "Creating release commit and publishing to Apache repository" # Artifact publishing git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ @@ -87,12 +91,25 @@ if [[ ! "$@" =~ --package-only ]]; then git commit -a -m "Preparing development version $next_ver" git push origin $GIT_TAG git push origin HEAD:$GIT_BRANCH - git checkout -f $GIT_TAG + popd + rm -rf spark +fi + +if [[ ! "$@" =~ --skip-publish ]]; then + git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git + pushd spark + git checkout --force $GIT_TAG + + # Substitute in case published version is different than released + old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" + new="\1${PUBLISH_VERSION}<\/version>" + find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ + -e "s/${old}/${new}/" {} # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG" + repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ -H "Content-Type:application/xml" -v \ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) @@ -101,13 +118,13 @@ if [[ ! "$@" =~ --package-only ]]; then rm -rf $SPARK_REPO - mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - - mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + + build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install @@ -122,8 +139,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . -type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id @@ -138,7 +161,7 @@ if [[ ! "$@" =~ --package-only ]]; then done echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG" + repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ -H "Content-Type:application/xml" -v \ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) @@ -149,88 +172,96 @@ if [[ ! "$@" =~ --package-only ]]; then rm -rf spark fi -# Source and binary tarballs -echo "Packaging release tarballs" -git clone https://git-wip-us.apache.org/repos/asf/spark.git -cd spark -git checkout --force $GIT_TAG -release_hash=`git rev-parse HEAD` - -rm .gitignore -rm -rf .git -cd .. - -cp -r spark spark-$RELEASE_VERSION -tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 -echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha -rm -rf spark-$RELEASE_VERSION - -make_binary_release() { - NAME=$1 - FLAGS=$2 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-version-to-2.11.sh - fi - - ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log +if [[ ! "$@" =~ --skip-package ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + git clone https://git-wip-us.apache.org/repos/asf/spark.git + cd spark + git checkout --force $GIT_TAG + release_hash=`git rev-parse HEAD` + + rm .gitignore + rm -rf .git cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - rm -rf spark-$RELEASE_VERSION-bin-$NAME - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha -} - - -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & -make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & -make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -wait - -# Copy data -echo "Copying release tarballs" -rc_folder=spark-$RELEASE_VERSION-$RC_NAME -ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder -scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - -# Docs -cd spark -build/sbt clean -cd docs -# Compile docs with Java 7 to use nicer format -JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build -echo "Copying release documentation" -rc_docs_folder=${rc_folder}-docs -ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder -rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - -echo "Release $RELEASE_VERSION completed:" -echo "Git tag:\t $GIT_TAG" -echo "Release commit:\t $release_hash" -echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" -echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" + + cp -r spark spark-$RELEASE_VERSION + tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ + --detach-sig spark-$RELEASE_VERSION.tgz + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ + spark-$RELEASE_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ + spark-$RELEASE_VERSION.tgz.sha + rm -rf spark-$RELEASE_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$RELEASE_VERSION-bin-$NAME + + cd spark-$RELEASE_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-version-to-2.11.sh + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . + + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ + --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ + MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ + spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ + SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ + spark-$RELEASE_VERSION-bin-$NAME.tgz.sha + } + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" "3030" & + make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" "3037" & + wait + rm -rf spark-$RELEASE_VERSION-bin-*/ + + # Copy data + echo "Copying release tarballs" + rc_folder=spark-$RELEASE_VERSION-$RC_NAME + ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_folder + scp spark-* \ + $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ + + # Docs + cd spark + sbt/sbt clean + cd docs + # Compile docs with Java 7 to use nicer format + JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build + echo "Copying release documentation" + rc_docs_folder=${rc_folder}-docs + ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder + rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder + + echo "Release $RELEASE_VERSION completed:" + echo "Git tag:\t $GIT_TAG" + echo "Release commit:\t $release_hash" + echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" + echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" +fi diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index b74e4ee8a330..0a599b5a6554 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -57,3 +57,37 @@ watermen - Yadong Qi witgo - Guoqiang Li xinyunh - Xinyun Huang zsxwing - Shixiong Zhu +Bilna - Bilna P +DoingDone9 - Doing Done +Earne - Ernest +FlytxtRnD - Meethu Mathew +GenTang - Gen TANG +JoshRosen - Josh Rosen +MechCoder - Manoj Kumar +OopsOutOfMemory - Sheng Li +Peishen-Jia - Peishen Jia +SaintBacchus - Huang Zhaowei +azagrebin - Andrey Zagrebin +bzz - Alexander Bezzubov +fjiang6 - Fan Jiang +gasparms - Gaspar Munoz +guowei2 - Guo Wei +hhbyyh - Yuhao Yang +hseagle - Peng Xu +javadba - Stephen Boesch +jbencook - Ben Cook +kul - Kuldeep +ligangty - Gang Li +marsishandsome - Liangliang Gu +medale - Markus Dale +nemccarthy - Nathan McCarthy +nxwhite-str - Nate Crosswhite +seayi - Xiaohua Yi +tianyi - Yi Tian +uncleGen - Uncle Gen +viper-kun - Xu Kun +x1- - Yuri Saito +zapletal-martin - Martin Zapletal +zuxqoj - Shekhar Bansal +mingyukim - Mingyu Kim +sigmoidanalytics - Mayur Rustagi diff --git a/dev/lint-python b/dev/lint-python index 772f856154ae..f50d149dc4d4 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,43 +19,54 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" cd "$SPARK_ROOT_DIR" +# compileall: https://docs.python.org/2/library/compileall.html +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +compile_status="${PIPESTATUS[0]}" + # Get pep8 at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 #+ TODOs: -#+ - Dynamically determine latest release version of pep8 and use that. -#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" -PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" - -curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" -curl_status=$? - -if [ $curl_status -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." - exit $curl_status -fi +#+ - Download pep8 from PyPI. It's more "official". +PEP8_VERSION="1.6.2" +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" +if [ ! -e "$PEP8_SCRIPT_PATH" ]; then + curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" + curl_status="$?" + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit "$curl_status" + fi +fi # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH" -pep8_status=${PIPESTATUS[0]} #$? +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +pep8_status="${PIPESTATUS[0]}" + +if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then + lint_status=0 +else + lint_status=1 +fi -if [ $pep8_status -ne 0 ]; then - echo "PEP 8 checks failed." - cat "$PEP8_REPORT_PATH" +if [ "$lint_status" -ne 0 ]; then + echo "Python lint checks failed." + cat "$PYTHON_LINT_REPORT_PATH" else - echo "PEP 8 checks passed." + echo "Python lint checks passed." fi -rm "$PEP8_REPORT_PATH" -rm "$PEP8_SCRIPT_PATH" +# rm "$PEP8_SCRIPT_PATH" +rm "$PYTHON_LINT_REPORT_PATH" -exit $pep8_status +exit "$lint_status" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index dfa924d2aa0b..b69cd15f99f6 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -55,8 +55,6 @@ # Prefix added to temporary branches BRANCH_PREFIX = "PR_TOOL" -os.chdir(SPARK_HOME) - def get_json(url): try: @@ -85,10 +83,6 @@ def continue_maybe(prompt): if result.lower() != "y": fail("Okay, exiting") - -original_head = run_cmd("git rev-parse HEAD")[:8] - - def clean_up(): print "Restoring head pointer to %s" % original_head run_cmd("git checkout %s" % original_head) @@ -101,7 +95,7 @@ def clean_up(): # merge the requested PR and return the merge hash -def merge_pr(pr_num, target_ref): +def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): pr_branch_name = "%s_MERGE_PR_%s" % (BRANCH_PREFIX, pr_num) target_branch_name = "%s_MERGE_PR_%s_%s" % (BRANCH_PREFIX, pr_num, target_ref.upper()) run_cmd("git fetch %s pull/%s/head:%s" % (PR_REMOTE_NAME, pr_num, pr_branch_name)) @@ -244,6 +238,8 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): versions = asf_jira.project_versions("SPARK") versions = sorted(versions, key=lambda x: x.name, reverse=True) versions = filter(lambda x: x.raw['released'] is False, versions) + # Consider only x.y.z versions + versions = filter(lambda x: re.match('\d+\.\d+\.\d+', x.name), versions) default_fix_versions = map(lambda x: fix_version_from_branch(x, versions).name, merge_branches) for v in default_fix_versions: @@ -272,7 +268,7 @@ def get_version_json(version_str): asf_jira.transition_issue( jira_id, resolve["id"], fixVersions=jira_fix_versions, comment=comment) - print "Succesfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) + print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) def resolve_jira_issues(title, merge_branches, comment): @@ -284,68 +280,155 @@ def resolve_jira_issues(title, merge_branches, comment): resolve_jira_issue(merge_branches, comment, jira_id) -branches = get_json("%s/branches" % GITHUB_API_BASE) -branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) -# Assumes branch names can be sorted lexicographically -latest_branch = sorted(branch_names, reverse=True)[0] - -pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") -pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) -pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) +def standardize_jira_ref(text): + """ + Standardize the [SPARK-XXXXX] [MODULE] prefix + Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" + + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") + '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' + >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") + '[SPARK-4123] [PROJECT INFRA] [WIP] Show new dependencies added in pull requests' + >>> standardize_jira_ref("[MLlib] Spark 5954: Top by key") + '[SPARK-5954] [MLLIB] Top by key' + >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") + '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' + >>> standardize_jira_ref("SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") + '[SPARK-1146] [WIP] Vagrant support for Spark' + >>> standardize_jira_ref("SPARK-1032. If Yarn app fails before registering, app master stays aroun...") + '[SPARK-1032] If Yarn app fails before registering, app master stays aroun...' + >>> standardize_jira_ref("[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.") + '[SPARK-6250] [SPARK-6146] [SPARK-5911] [SQL] Types are now reserved words in DDL parser.' + >>> standardize_jira_ref("Additional information for users building from source code") + 'Additional information for users building from source code' + """ + jira_refs = [] + components = [] + + # If the string is compliant, no need to process any further + if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): + return text + + # Extract JIRA ref(s): + pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) + for ref in pattern.findall(text): + # Add brackets, replace spaces with a dash, & convert to uppercase + jira_refs.append('[' + re.sub(r'\s+', '-', ref.upper()) + ']') + text = text.replace(ref, '') + + # Extract spark component(s): + # Look for alphanumeric chars, spaces, dashes, periods, and/or commas + pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + for component in pattern.findall(text): + components.append(component.upper()) + text = text.replace(component, '') + + # Cleanup any remaining symbols: + pattern = re.compile(r'^\W+(.*)', re.IGNORECASE) + if (pattern.search(text) is not None): + text = pattern.search(text).groups()[0] + + # Assemble full text (JIRA ref(s), module(s), remaining text) + clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() + + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included + clean_text = re.sub(r'\s+', ' ', clean_text.strip()) + + return clean_text + +def main(): + global original_head + + os.chdir(SPARK_HOME) + original_head = run_cmd("git rev-parse HEAD")[:8] + + branches = get_json("%s/branches" % GITHUB_API_BASE) + branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) + # Assumes branch names can be sorted lexicographically + latest_branch = sorted(branch_names, reverse=True)[0] + + pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) + pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) + + url = pr["url"] + + # Decide whether to use the modified title or not + modified_title = standardize_jira_ref(pr["title"]) + if modified_title != pr["title"]: + print "I've re-written the title as follows to match the standard format:" + print "Original: %s" % pr["title"] + print "Modified: %s" % modified_title + result = raw_input("Would you like to use the modified title? (y/n): ") + if result.lower() == "y": + title = modified_title + print "Using modified title:" + else: + title = pr["title"] + print "Using original title:" + print title + else: + title = pr["title"] -url = pr["url"] -title = pr["title"] -body = pr["body"] -target_ref = pr["base"]["ref"] -user_login = pr["user"]["login"] -base_ref = pr["head"]["ref"] -pr_repo_desc = "%s/%s" % (user_login, base_ref) + body = pr["body"] + target_ref = pr["base"]["ref"] + user_login = pr["user"]["login"] + base_ref = pr["head"]["ref"] + pr_repo_desc = "%s/%s" % (user_login, base_ref) -# Merged pull requests don't appear as merged in the GitHub API; -# Instead, they're closed by asfgit. -merge_commits = \ - [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + # Merged pull requests don't appear as merged in the GitHub API; + # Instead, they're closed by asfgit. + merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] -if merge_commits: - merge_hash = merge_commits[0]["commit_id"] - message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] - print "Pull request %s has already been merged, assuming you want to backport" % pr_num - commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + print "Pull request %s has already been merged, assuming you want to backport" % pr_num + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', "%s^{commit}" % merge_hash]).strip() != "" - if not commit_is_downloaded: - fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) + if not commit_is_downloaded: + fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - print "Found commit %s:\n%s" % (merge_hash, message) - cherry_pick(pr_num, merge_hash, latest_branch) - sys.exit(0) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) + sys.exit(0) -if not bool(pr["mergeable"]): - msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ - "Continue? (experts only!)" - continue_maybe(msg) + if not bool(pr["mergeable"]): + msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ + "Continue? (experts only!)" + continue_maybe(msg) -print ("\n=== Pull Request #%s ===" % pr_num) -print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( - title, pr_repo_desc, target_ref, url)) -continue_maybe("Proceed with merging pull request #%s?" % pr_num) + print ("\n=== Pull Request #%s ===" % pr_num) + print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( + title, pr_repo_desc, target_ref, url)) + continue_maybe("Proceed with merging pull request #%s?" % pr_num) -merged_refs = [target_ref] + merged_refs = [target_ref] -merge_hash = merge_pr(pr_num, target_ref) + merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) -pick_prompt = "Would you like to pick %s into another branch?" % merge_hash -while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": - merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] + pick_prompt = "Would you like to pick %s into another branch?" % merge_hash + while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] -if JIRA_IMPORTED: - if JIRA_USERNAME and JIRA_PASSWORD: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira_issues(title, merged_refs, jira_comment) + if JIRA_IMPORTED: + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira_issues(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: - print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." -else: - print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." - print "Exiting without trying to close the associated JIRA." + +if __name__ == "__main__": + import doctest + doctest.testmod() + + main() diff --git a/dev/mima b/dev/mima index bed5cd042634..2952fa65d42f 100755 --- a/dev/mima +++ b/dev/mima @@ -27,16 +27,21 @@ cd "$FWDIR" echo -e "q\n" | build/sbt oldDeps/update rm -f .generated-mima* +generate_mima_ignore() { + SPARK_JAVA_OPTS="-XX:MaxPermSize=1g -Xmx2g" \ + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +} + # Generate Mima Ignore is called twice, first with latest built jars # on the classpath and then again with previous version jars on the classpath. # Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath # it did not process the new classes (which are in assembly jar). -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/dev/run-tests b/dev/run-tests index 2257a566bb1b..861d1671182c 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,7 +36,7 @@ function handle_error () { } -# Build against the right verison of Hadoop. +# Build against the right version of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then @@ -77,7 +77,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" fi } -# Only run Hive tests if there are sql changes. +# Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master @@ -141,31 +141,52 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_BUILD { + HIVE_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" + HIVE_12_BUILD_ARGS="$HIVE_BUILD_ARGS -Phive-0.12.0" - # NOTE: echo "q" is needed because sbt on encountering a build file with failure - # (either resolution or compilation) prompts the user for input either q, r, etc - # to quit or retry. This echo is there to make it not block. - # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a - # single argument! - # QUESTION: Why doesn't 'yes "q"' work? - # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? # First build with Hive 0.12.0 to ensure patches do not break the Hive 0.12.0 build - HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0" echo "[info] Compile with Hive 0.12.0" - echo -e "q\n" \ - | build/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ - | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + [ -d "lib_managed" ] && rm -rf lib_managed + echo "[info] Building Spark with these arguments: $HIVE_12_BUILD_ARGS" + + if [ "${AMPLAB_JENKINS_BUILD_TOOL}" == "maven" ]; then + build/mvn $HIVE_12_BUILD_ARGS clean package -DskipTests + else + # NOTE: echo "q" is needed because sbt on encountering a build file with failure + # (either resolution or compilation) prompts the user for input either q, r, etc + # to quit or retry. This echo is there to make it not block. + # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a + # single argument! + # QUESTION: Why doesn't 'yes "q"' work? + # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + echo -e "q\n" \ + | build/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + fi # Then build with default Hive version (0.13.1) because tests are based on this version echo "[info] Compile with Hive 0.13.1" - rm -rf lib_managed - echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\ - " -Phive -Phive-thriftserver" - echo -e "q\n" \ - | build/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ - | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + [ -d "lib_managed" ] && rm -rf lib_managed + echo "[info] Building Spark with these arguments: $HIVE_BUILD_ARGS" + + if [ "${AMPLAB_JENKINS_BUILD_TOOL}" == "maven" ]; then + build/mvn $HIVE_BUILD_ARGS clean package -DskipTests + else + echo -e "q\n" \ + | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + fi } +echo "" +echo "=========================================================================" +echo "Detecting binary incompatibilities with MiMa" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + +./dev/mima + echo "" echo "=========================================================================" echo "Running Spark unit tests" @@ -183,24 +204,28 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string # will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}" - # NOTE: echo "q" is needed because sbt on encountering a build file with failure - # (either resolution or compilation) prompts the user for input either q, r, etc - # to quit or retry. This echo is there to make it not block. - # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a - # single argument! - # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array. - # QUESTION: Why doesn't 'yes "q"' work? - # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? - echo -e "q\n" \ - | build/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ - | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + if [ "${AMPLAB_JENKINS_BUILD_TOOL}" == "maven" ]; then + build/mvn test $SBT_MAVEN_PROFILES_ARGS --fail-at-end + else + # NOTE: echo "q" is needed because sbt on encountering a build file with failure + # (either resolution or compilation) prompts the user for input either q, r, etc + # to quit or retry. This echo is there to make it not block. + # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a + # single argument! + # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array. + # QUESTION: Why doesn't 'yes "q"' work? + # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + echo -e "q\n" \ + | build/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + fi } echo "" @@ -210,13 +235,21 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS +# add path for python 3 in jenkins +export PATH="${PATH}:/home/anaonda/envs/py3k/bin" ./python/run-tests echo "" echo "=========================================================================" -echo "Detecting binary incompatibilities with MiMa" +echo "Running SparkR tests" echo "=========================================================================" -CURRENT_BLOCK=$BLOCK_MIMA +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi -./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 1348e0609dda..154e01255b2e 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -22,6 +22,7 @@ readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 readonly BLOCK_BUILD=14 -readonly BLOCK_SPARK_UNIT_TESTS=15 -readonly BLOCK_PYSPARK_UNIT_TESTS=16 -readonly BLOCK_MIMA=17 +readonly BLOCK_MIMA=15 +readonly BLOCK_SPARK_UNIT_TESTS=16 +readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 6a849e4f7720..030f2cdddb35 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,7 +47,23 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout +TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout + +# Array to capture all tests to run on the pull request. These tests are held under the +#+ dev/tests/ directory. +# +# To write a PR test: +#+ * the file must reside within the dev/tests directory +#+ * be an executable bash script +#+ * accept three arguments on the command line, the first being the Github PR long commit +#+ hash, the second the Github SHA1 hash, and the final the current PR hash +#+ * and, lastly, return string output to be included in the pr message output that will +#+ be posted to Github +PR_TESTS=( + "pr_merge_ability" + "pr_public_classes" + "pr_new_dependencies" +) function post_message () { local message=$1 @@ -131,61 +147,42 @@ function send_archived_logs () { fi } - -# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR -#+ and not anything else added to master since the PR was branched. - -# check PR merge-ability and check for new public classes -{ - if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not merge cleanly**." - else - merge_note=" * This patch merges cleanly." - fi - - source_files=$( - git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ -z "$new_public_classes" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi -} - # post start message { start_message="\ [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." - start_message="${start_message}\n${merge_note}" - # start_message="${start_message}\n${public_classes_note}" - post_message "$start_message" } +# Environment variable to capture PR test output +pr_message="" +# Ensure we save off the current HEAD to revert to +current_pr_head="`git rev-parse HEAD`" + +echo "HEAD: `git rev-parse HEAD`" +echo "GHPRB: $ghprbActualCommit" +echo "SHA1: $sha1" + +# Run pull request tests +for t in "${PR_TESTS[@]}"; do + this_test="${FWDIR}/dev/tests/${t}.sh" + # Ensure the test can be found and is a file + if [ -f "${this_test}" ]; then + echo "Running test: $t" + this_mssg="$(bash "${this_test}" "${ghprbActualCommit}" "${sha1}" "${current_pr_head}")" + # Check if this is the merge test as we submit that note *before* and *after* + # the tests run + [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}" + pr_message="${pr_message}\n${this_mssg}" + # Ensure, after each test, that we're back on the current PR + git checkout -f "${current_pr_head}" &>/dev/null + else + echo "Cannot find test ${this_test}." + fi +done + # run tests { timeout "${TESTS_TIMEOUT}" ./dev/run-tests @@ -211,12 +208,14 @@ function send_archived_logs () { failing_test="Python style tests" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" - elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then - failing_test="MiMa tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi @@ -234,8 +233,7 @@ function send_archived_logs () { PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" - result_message="${result_message}\n${merge_note}" - result_message="${result_message}\n${public_classes_note}" + result_message="${result_message}${pr_message}" post_message "$result_message" } diff --git a/dev/scalastyle b/dev/scalastyle index 86919227ed1a..4e03f89ed5d5 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -18,9 +18,10 @@ # echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | build/sbt -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ - >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 test:scalastyle >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh new file mode 100755 index 000000000000..d9a347fe24a8 --- /dev/null +++ b/dev/tests/pr_merge_ability.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# + +ghprbActualCommit="$1" +sha1="$2" + +# check PR merge-ability +if [ "${sha1}" == "${ghprbActualCommit}" ]; then + echo " * This patch **does not merge cleanly**." +else + echo " * This patch merges cleanly." +fi diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh new file mode 100755 index 000000000000..fdfb3c62aff5 --- /dev/null +++ b/dev/tests/pr_new_dependencies.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# Arg3: Current PR Commit Hash +#+ the PR hash for the current commit +# + +ghprbActualCommit="$1" +sha1="$2" +current_pr_head="$3" + +MVN_BIN="build/mvn" +CURR_CP_FILE="my-classpath.txt" +MASTER_CP_FILE="master-classpath.txt" + +# First switch over to the master branch +git checkout -f master +# Find and copy all pom.xml files into a *.gate file that we can check +# against through various `git` changes +find -name "pom.xml" -exec cp {} {}.gate \; +# Switch back to the current PR +git checkout -f "${current_pr_head}" + +# Check if any *.pom files from the current branch are different from the master +difference_q="" +for p in $(find -name "pom.xml"); do + [[ -f "${p}" && -f "${p}.gate" ]] && \ + difference_q="${difference_q}$(diff $p.gate $p)" +done + +# If no pom files were changed we can easily say no new dependencies were added +if [ -z "${difference_q}" ]; then + echo " * This patch does not change any dependencies." +else + # Else we need to manually build spark to determine what, if any, dependencies + # were added into the Spark assembly jar + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${CURR_CP_FILE} + + # Checkout the master branch to compare against + git checkout -f master + + ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ + sed -n -e '/Building Spark Project Assembly/,$p' | \ + grep --context=1 -m 2 "Dependencies classpath:" | \ + head -n 3 | \ + tail -n 1 | \ + tr ":" "\n" | \ + rev | \ + cut -d "/" -f 1 | \ + rev | \ + sort > ${MASTER_CP_FILE} + + DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" + + if [ -z "${DIFF_RESULTS}" ]; then + echo " * This patch does not change any dependencies." + else + # Pretty print the new dependencies + added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') + added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" + removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" + + # Construct the final returned message with proper + return_mssg="" + [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" + if [ -n "${removed_deps}" ]; then + if [ -n "${return_mssg}" ]; then + return_mssg="${return_mssg}\n${removed_deps_text}" + else + return_mssg="${removed_deps_text}" + fi + fi + echo "${return_mssg}" + fi + + # Remove the files we've left over + [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" + [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" + + # Clean up our mess from the Maven builds just in case + ${MVN_BIN} clean &>/dev/null +fi diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh new file mode 100755 index 000000000000..927295b88c96 --- /dev/null +++ b/dev/tests/pr_public_classes.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# This script follows the base format for testing pull requests against +# another branch and returning results to be published. More details can be +# found at dev/run-tests-jenkins. +# +# Arg1: The Github Pull Request Actual Commit +#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# Arg2: The SHA1 hash +#+ known as `sha1` in `run-tests-jenkins` +# + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + +ghprbActualCommit="$1" +sha1="$2" + +source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " +) +new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` +) + +if [ -z "$new_public_classes" ]; then + echo " * This patch adds no public classes." +else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + echo "${public_classes_note}\n${new_public_classes}" +fi diff --git a/docs/README.md b/docs/README.md index 8a54724c4bea..5852f972a051 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,19 +58,25 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools -You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory. +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) + +You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a -jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it +jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index e2db274e1f61..b22b627f0900 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,11 +10,12 @@ kramdown: include: - _static + - _modules # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.3.0 +SPARK_VERSION: 1.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.4.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 8841f7675d35..b92c75f90b11 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -7,7 +7,9 @@ {{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation - + {% if page.description %} + + {% endif %} {% if page.redirect %} @@ -69,7 +71,7 @@
    • Spark Programming Guide
    • Spark Streaming
    • -
    • Spark SQL
    • +
    • DataFrames and SQL
    • MLlib (Machine Learning)
    • GraphX (Graph Processing)
    • Bagel (Pregel on Spark)
    • @@ -82,6 +84,7 @@
    • Scala
    • Java
    • Python
    • +
    • R
    • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f54..0ea3f8eab461 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 7e55131754a3..c2fe6b0e286c 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Bagel Programming Guide +displayTitle: Bagel Programming Guide +title: Bagel --- **Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** diff --git a/docs/building-spark.md b/docs/building-spark.md index fb93017861ed..ea79c5bc276d 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -9,6 +9,10 @@ redirect_from: "building-with-maven.html" Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+. +**Note:** Building Spark with Java 7 or later can create JAR files that may not be +readable with early versions of Java 6, due to the large number of files in the JAR +archive. Build with Java 6 if this is an issue for your deployment. + # Building with `build/mvn` Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](http://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: @@ -19,6 +23,18 @@ build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package Other build examples can be found below. +**Note:** When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: + + -Xmax-classfile-name + 128 + +and in `project/SparkBuild.scala` add: + + scalacOptions in Compile ++= Seq("-Xmax-classfile-name", "128"), + +to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/pull/2883/files) if you are unsure of where to add these lines. + + # Setting up Maven's Memory Usage You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`. We recommend the following settings: @@ -111,9 +127,9 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark is experimental and does not support a few features. -Specifically, Spark's external Kafka library and JDBC component are not yet -supported in Scala 2.11 builds. +Scala 2.11 support in Spark does not support a few features due to dependencies +which are themselves not Scala 2.11 ready. Specifically, Spark's external +Kafka library and JDBC component are not yet supported in Scala 2.11 builds. # Spark Tests in Maven @@ -137,15 +153,18 @@ We use the scala-maven-plugin which supports incremental and continuous compilat should run continuous compilation (i.e. wait for changes). However, this has not been tested extensively. A couple of gotchas to note: + * it only scans the paths `src/main` and `src/test` (see [docs](http://scala-tools.org/mvnsites/maven-scala-plugin/usage_cc.html)), so it will only work from within certain submodules that have that structure. + * you'll typically need to run `mvn install` from the project root for compilation within specific submodules to work; this is because submodules that depend on other submodules do so via the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: - ``` + +``` $ mvn install $ cd core $ mvn scala:cc @@ -156,14 +175,6 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the [wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup). -# Building Spark Debian Packages - -The Maven build includes support for building a Debian package containing the assembly 'fat-jar', PySpark, and the necessary scripts and configuration files. This can be created by specifying the following: - - mvn -Pdeb -DskipTests clean package - -The debian package can then be found under assembly/target. We added the short commit hash to the file name so that we can distinguish individual packages built for SNAPSHOT versions. - # Running Java 8 Test Suites Running only Java 8 tests and nothing else. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f0..7079de546e2f 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/configuration.md b/docs/configuration.md index efbab4085317..d587b91124cb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Configuration +displayTitle: Spark Configuration +title: Configuration --- * This will become a table of contents (this text will be scraped). {:toc} @@ -34,9 +35,19 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually -require one to prevent any sort of starvation issues. +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +actually require one to prevent any sort of starvation issues. +Properties that specify some time duration should be configured with a unit of time. +The following format is accepted: + + 25ms (milliseconds) + 5s (seconds) + 10m or 10min (minutes) + 3h (hours) + 5d (days) + 1y (years) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -69,7 +80,9 @@ each line consists of a key and a value separated by whitespace. For example: Any values specified as flags or in the properties file will be passed on to the application and merged with those specified through SparkConf. Properties set directly on the SparkConf take highest precedence, then flags passed to `spark-submit` or `spark-shell`, then options -in the `spark-defaults.conf` file. +in the `spark-defaults.conf` file. A few configuration keys have been renamed since earlier +versions of Spark; in such cases, the older key names are still accepted, but take lower +precedence than any instance of the newer key. ## Viewing Spark Properties @@ -93,14 +106,6 @@ of the most common options to set are: The name of your application. This will appear in the UI and in log data.
      - - - - - @@ -108,23 +113,6 @@ of the most common options to set are: Number of cores to use for the driver process, only in cluster mode. - - - - - - - - - - - - - + + - - + + - + @@ -190,6 +175,14 @@ of the most common options to set are: Logs the effective SparkConf as INFO when a SparkContext is started. + + + + +
      Executor IDTotal Tasks Failed Tasks Succeeded TasksInputOutputShuffle ReadShuffle WriteShuffle Spill (Memory)Shuffle Spill (Disk) + Input Size / Records + + Output Size / Records + + + Shuffle Read Size / Records + + + Shuffle Write Size / Records + Shuffle Spill (Memory)Shuffle Spill (Disk)
      {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - {Utils.bytesToString(v.inputBytes)} - {Utils.bytesToString(v.outputBytes)} - {Utils.bytesToString(v.shuffleRead)} - {Utils.bytesToString(v.shuffleWrite)} - {Utils.bytesToString(v.memoryBytesSpilled)} - {Utils.bytesToString(v.diskBytesSpilled)} + {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} + + {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} + + {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} + + {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} + + {Utils.bytesToString(v.memoryBytesSpilled)} + + {Utils.bytesToString(v.diskBytesSpilled)} +
      {acc.name}{acc.value}
      {acc.name}{acc.value}
      {UIUtils.formatDuration(millis.toLong)} @@ -258,30 +290,88 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantiles(data: Seq[Double]) = - Distribution(data).get.getQuantiles().map(d => {Utils.bytesToString(d.toLong)}{Utils.bytesToString(d.toLong)}{s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"}InputInput Size / RecordsOutputOutput Size / Records + + Shuffle Read Blocked Time + + + + Shuffle Read Size / Records + + Shuffle Read (Remote) + + Shuffle Remote Reads + + Shuffle WriteShuffle Write Size / Records
      - {inputReadable} + {s"$inputReadable / $inputRecords"} - {outputReadable} + {s"$outputReadable / $outputRecords"} + {shuffleReadBlockedTimeReadable} + - {shuffleReadReadable} + {s"$shuffleReadReadable / $shuffleReadRecords"} + + {shuffleReadRemoteReadable} - {shuffleWriteReadable} + {s"$shuffleWriteReadable / $shuffleWriteRecords"} {errorSummary}{details}
      spark.master(none) - The cluster manager to connect to. See the list of - allowed master URL's. -
      spark.driver.cores 1
      spark.driver.memory512m - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 512m, 2g). -
      spark.executor.memory512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
      spark.driver.maxResultSize 1g @@ -137,38 +125,35 @@ of the most common options to set are:
      spark.serializerorg.apache.spark.serializer.
      JavaSerializer
      spark.driver.memory512m - Class to use for serializing objects that will be sent over the network or need to be cached - in serialized form. The default of Java serialization works with any Serializable Java object - but is quite slow, so we recommend using - org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization - when speed is necessary. Can be any subclass of - - org.apache.spark.Serializer. + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + +
      Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-memory command line option + or in your default properties file.
      spark.kryo.classesToRegister(none)spark.executor.memory512m - If you use Kryo serialization, give a comma-separated list of custom class names to register - with Kryo. - See the tuning guide for more details. + Amount of memory to use per executor process, in the same format as JVM memory strings + (e.g. 512m, 2g).
      spark.kryo.registratorspark.extraListeners (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This - property is useful if you need to register your classes in a custom way, e.g. to specify a custom - field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends - - KryoRegistrator. - See the tuning guide for more details. + A comma-separated list of classes that implement SparkListener; when initializing + SparkContext, instances of these classes will be created and registered with Spark's listener + bus. If a class has a single-argument constructor that accepts a SparkConf, that constructor + will be called; otherwise, a zero-argument constructor will be called. If no valid constructor + can be found, the SparkContext creation will fail with an exception.
      spark.master(none) + The cluster manager to connect to. See the list of + allowed master URL's. +
      Apart from these, the following properties are also available, and may be useful in some situations: @@ -198,51 +191,84 @@ Apart from these, the following properties are also available, and may be useful - + + + + + + + + + + + + + + + + + + + - + - + - - + + @@ -256,30 +282,40 @@ Apart from these, the following properties are also available, and may be useful - + + + + + + - + - - + + @@ -290,6 +326,9 @@ Apart from these, the following properties are also available, and may be useful or it will be displayed before the driver exiting. It also can be dumped into disk by `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, they will not be displayed automatically before driver exiting. + + By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by + passing a profiler class in as a parameter to the `SparkContext` constructor. @@ -302,6 +341,15 @@ Apart from these, the following properties are also available, and may be useful automatically. + + + + + @@ -312,40 +360,38 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. +
      Property NameDefaultMeaning
      spark.executor.extraJavaOptionsspark.driver.extraClassPath (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other - logging. Note that it is illegal to set Spark properties or heap size settings with this - option. Spark properties should be set using a SparkConf object or the - spark-defaults.conf file used with the spark-submit script. Heap size settings can be set - with spark.executor.memory. + Extra classpath entries to append to the classpath of the driver. + +
      Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-class-path command line option or in + your default properties file.
      spark.driver.extraJavaOptions(none) + A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + +
      Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-java-options command line option or in + your default properties file.
      spark.driver.extraLibraryPath(none) + Set a special library path to use when launching the driver JVM. + +
      Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-library-path command line option or in + your default properties file.
      spark.driver.userClassPathFirstfalse + (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading + classes in the the driver. This feature can be used to mitigate conflicts between Spark's + dependencies and user dependencies. It is currently an experimental feature. + + This is used in cluster mode only.
      spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily - for backwards-compatibility with older versions of Spark. Users typically should not need - to set this option. + Extra classpath entries to append to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set + this option.
      spark.executor.extraLibraryPathspark.executor.extraJavaOptions (none) - Set a special library path to use when launching executor JVM's. + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the + spark-submit script. Heap size settings can be set with spark.executor.memory.
      spark.executor.logs.rolling.strategyspark.executor.extraLibraryPath (none) - Set the strategy of rolling of executor logs. By default it is disabled. It can - be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", - use spark.executor.logs.rolling.time.interval to set the rolling interval. - For "size", use spark.executor.logs.rolling.size.maxBytes to set - the maximum file size for rolling. + Set a special library path to use when launching executor JVM's.
      spark.executor.logs.rolling.time.intervaldailyspark.executor.logs.rolling.maxRetainedFiles(none) - Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or - any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles - for automatic cleaning of old logs. + Sets the number of latest rolling log files that are going to be retained by the system. + Older log files will be deleted. Disabled by default.
      spark.executor.logs.rolling.maxRetainedFilesspark.executor.logs.rolling.strategy (none) - Sets the number of latest rolling log files that are going to be retained by the system. - Older log files will be deleted. Disabled by default. + Set the strategy of rolling of executor logs. By default it is disabled. It can + be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", + use spark.executor.logs.rolling.time.interval to set the rolling interval. + For "size", use spark.executor.logs.rolling.size.maxBytes to set + the maximum file size for rolling. +
      spark.executor.logs.rolling.time.intervaldaily + Set the time interval by which the executor logs will be rolled over. + Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles + for automatic cleaning of old logs.
      spark.files.userClassPathFirstspark.executor.userClassPathFirst false - (Experimental) Whether to give user-added jars precedence over Spark's own jars when - loading classes in Executors. This feature can be used to mitigate conflicts between - Spark's dependencies and user dependencies. It is currently an experimental feature. - (Currently, this setting does not work for YARN, see SPARK-2996 for more details). + (Experimental) Same functionality as spark.driver.userClassPathFirst, but + applied to executor instances.
      spark.python.worker.memory512mspark.executorEnv.[EnvironmentVariableName](none) - Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + Add the environment variable specified by EnvironmentVariableName to the Executor + process. The user can specify multiple of these to set multiple environment variables.
      spark.python.worker.memory512m + Amount of memory to use per python worker process during aggregation, in the same + format as JVM memory strings (e.g. 512m, 2g). If the memory + used during aggregation goes above this amount, it will spill the data into disks. +
      spark.python.worker.reuse true
      + +#### Shuffle Behavior + + - - + + - - + + - - + + -
      Property NameDefaultMeaning
      spark.executorEnv.[EnvironmentVariableName](none)spark.reducer.maxMbInFlight48 - Add the environment variable specified by EnvironmentVariableName to the Executor - process. The user can specify multiple of these to set multiple environment variables. + Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + each output requires us to create a buffer to receive it, this represents a fixed memory + overhead per reduce task, so keep it small unless you have a large amount of memory.
      spark.mesos.executor.homedriver side SPARK_HOMEspark.shuffle.blockTransferServicenetty - Set the directory in which Spark is installed on the executors in Mesos. By default, the - executors will simply use the driver's Spark home directory, which may not be visible to - them. Note that this is only relevant if a Spark binary package is not specified through - spark.executor.uri. + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2.
      spark.mesos.executor.memoryOverheadexecutor memory * 0.07, with minimum of 384spark.shuffle.compresstrue - This value is an additive for spark.executor.memory, specified in MiB, - which is used to calculate the total Mesos task memory. A value of 384 - implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum - overhead. The final overhead will be the larger of either - `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. + Whether to compress map output files. Generally a good idea. Compression will use + spark.io.compression.codec.
      - -#### Shuffle Behavior - - @@ -357,55 +403,46 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + + - + - - - - - - - + + @@ -417,6 +454,17 @@ Apart from these, the following properties are also available, and may be useful the default option starting in 1.2. + + + + + @@ -426,13 +474,19 @@ Apart from these, the following properties are also available, and may be useful - - + + + + + + +
      Property NameDefaultMeaning
      spark.shuffle.consolidateFiles false
      spark.shuffle.spilltruespark.shuffle.file.buffer.kb32 - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. + Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
      spark.shuffle.spill.compresstruespark.shuffle.io.maxRetries3 - Whether to compress data spilled during shuffles. Compression will use - spark.io.compression.codec. + (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is + set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC + pauses or transient network connectivity issues.
      spark.shuffle.memoryFraction0.2spark.shuffle.io.numConnectionsPerPeer1 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of - all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will - begin to spill to disk. If spills are often, consider increasing this value at the expense of - spark.storage.memoryFraction. + (Netty only) Connections between hosts are reused in order to reduce connection buildup for + large clusters. For clusters with many hard disks and few hosts, this may result in insufficient + concurrency to saturate all disks, and so users may consider increasing this value.
      spark.shuffle.compressspark.shuffle.io.preferDirectBufs true - Whether to compress map output files. Generally a good idea. Compression will use - spark.io.compression.codec. -
      spark.shuffle.file.buffer.kb32 - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache + block transfer. For environments where off-heap memory is tightly limited, users may wish to + turn this off to force all allocations from Netty to be on-heap.
      spark.reducer.maxMbInFlight48spark.shuffle.io.retryWait5s - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + (Netty only) How long to wait between retries of fetches. The maximum delay caused by retrying + is 15 seconds by default, calculated as maxRetries * retryWait.
      spark.shuffle.memoryFraction0.2 + Fraction of Java heap to use for aggregation and cogroups during shuffles, if + spark.shuffle.spill is true. At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills are often, consider increasing this value at the expense of + spark.storage.memoryFraction. +
      spark.shuffle.sort.bypassMergeThreshold 200
      spark.shuffle.blockTransferServicenettyspark.shuffle.spilltrue - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + If set to "true", limits the amount of memory used during reduces by spilling data out to disk. + This spilling threshold is specified by spark.shuffle.memoryFraction. +
      spark.shuffle.spill.compresstrue + Whether to compress data spilled during shuffles. Compression will use + spark.io.compression.codec.
      @@ -441,26 +495,28 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + + Base directory in which Spark events are logged, if spark.eventLog.enabled is true. + Within this base directory, Spark creates a sub-directory for each application, and logs the + events specific to the application in this directory. Users may want to set this to + a unified location like an HDFS directory so history files can be read by the history server. + - - + + @@ -471,28 +527,26 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + - - + +
      Property NameDefaultMeaning
      spark.ui.port4040spark.eventLog.compressfalse - Port for your application's dashboard, which shows memory and workload data. + Whether to compress logged events, if spark.eventLog.enabled is true.
      spark.ui.retainedStages1000spark.eventLog.dirfile:///tmp/spark-events - How many stages the Spark UI and status APIs remember before garbage - collecting. -
      spark.ui.retainedJobs1000spark.eventLog.enabledfalse - How many jobs the Spark UI and status APIs remember before garbage - collecting. + Whether to log Spark events, useful for reconstructing the Web UI after the application has + finished.
      spark.eventLog.enabledfalsespark.ui.port4040 - Whether to log Spark events, useful for reconstructing the Web UI after the application has - finished. + Port for your application's dashboard, which shows memory and workload data.
      spark.eventLog.compressfalsespark.ui.retainedJobs1000 - Whether to compress logged events, if spark.eventLog.enabled is true. + How many jobs the Spark UI and status APIs remember before garbage + collecting.
      spark.eventLog.dirfile:///tmp/spark-eventsspark.ui.retainedStages1000 - Base directory in which Spark events are logged, if spark.eventLog.enabled is true. - Within this base directory, Spark creates a sub-directory for each application, and logs the - events specific to the application in this directory. Users may want to set this to - a unified location like an HDFS directory so history files can be read by the history server. + How many stages the Spark UI and status APIs remember before garbage + collecting.
      @@ -508,12 +562,10 @@ Apart from these, the following properties are also available, and may be useful - spark.rdd.compress - false + spark.closure.serializer + org.apache.spark.serializer.
      JavaSerializer - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + Serializer class to use for closures. Currently only the Java serializer is supported. @@ -529,14 +581,6 @@ Apart from these, the following properties are also available, and may be useful and org.apache.spark.io.SnappyCompressionCodec. - - spark.io.compression.snappy.block.size - 32768 - - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec - is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - - spark.io.compression.lz4.block.size 32768 @@ -546,21 +590,20 @@ Apart from these, the following properties are also available, and may be useful - spark.closure.serializer - org.apache.spark.serializer.
      JavaSerializer + spark.io.compression.snappy.block.size + 32768 - Serializer class to use for closures. Currently only the Java serializer is supported. + Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. - spark.serializer.objectStreamReset - 100 + spark.kryo.classesToRegister + (none) - When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches - objects to prevent writing redundant data, however that stops garbage collection of those - objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to -1. - By default it will reset the serializer every 100 objects. + If you use Kryo serialization, give a comma-separated list of custom class names to register + with Kryo. + See the tuning guide for more details. @@ -585,12 +628,16 @@ Apart from these, the following properties are also available, and may be useful - spark.kryoserializer.buffer.mb - 0.064 + spark.kryo.registrator + (none) - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + If you use Kryo serialization, set this class to register your custom classes with Kryo. This + property is useful if you need to register your classes in a custom way, e.g. to specify a custom + field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be + set to a class that extends + + KryoRegistrator. + See the tuning guide for more details. @@ -602,11 +649,91 @@ Apart from these, the following properties are also available, and may be useful inside Kryo. + + spark.kryoserializer.buffer.mb + 0.064 + + Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max.mb if needed. + + + + spark.rdd.compress + false + + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some + extra CPU time. + + + + spark.serializer + org.apache.spark.serializer.
      JavaSerializer + + Class to use for serializing objects that will be sent over the network or need to be cached + in serialized form. The default of Java serialization works with any Serializable Java object + but is quite slow, so we recommend using + org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization + when speed is necessary. Can be any subclass of + + org.apache.spark.Serializer. + + + + spark.serializer.objectStreamReset + 100 + + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data, however that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, and allow old + objects to be collected. To turn off this periodic reset set it to -1. + By default it will reset the serializer every 100 objects. + + #### Execution Behavior + + + + + + + + + + + + + + + + + + + + - - + + + + + + + - - + + @@ -649,12 +787,23 @@ Apart from these, the following properties are also available, and may be useful - - - + + + + + + + + @@ -665,6 +814,15 @@ Apart from these, the following properties are also available, and may be useful increase it if you configure your own old generation size. + + + + + @@ -683,15 +841,6 @@ Apart from these, the following properties are also available, and may be useful directories on Tachyon file system. - - - - - @@ -699,211 +848,168 @@ Apart from these, the following properties are also available, and may be useful The URL of the underlying Tachyon file system in the TachyonStore. - - - - - - - - - - - - - - - - - - - -
      Property NameDefaultMeaning
      spark.broadcast.blockSize4096 + Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Too large a value decreases parallelism during broadcast (makes it slower); however, if it is + too small, BlockManager might take a performance hit. +
      spark.broadcast.factoryorg.apache.spark.broadcast.
      TorrentBroadcastFactory
      + Which broadcast implementation to use. +
      spark.cleaner.ttl(infinite) + Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks + generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be + forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in + case of Spark Streaming applications). Note that any RDD that persists in memory for more than + this duration will be cleared as well. +
      spark.executor.cores1 in YARN mode, all the available cores on the worker in standalone mode. + The number of cores to use on each executor. For YARN and standalone mode only. + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one + executor per application will run on each worker. +
      spark.default.parallelism @@ -625,19 +752,30 @@ Apart from these, the following properties are also available, and may be useful
      spark.broadcast.factoryorg.apache.spark.broadcast.
      TorrentBroadcastFactory
      spark.executor.heartbeatInterval10sInterval between each executor's heartbeats to the driver. Heartbeats let + the driver know that the executor is still alive and update it with metrics for in-progress + tasks.
      spark.files.fetchTimeout60s - Which broadcast implementation to use. + Communication timeout to use when fetching files added through SparkContext.addFile() from + the driver.
      spark.broadcast.blockSize4096spark.files.useFetchCachetrue - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + If set to true (default), file fetching will use a local cache that is shared by executors + that belong to the same application, which can improve task launching performance when + running many executors on the same host. If set to false, these caching optimizations will + be disabled and all executors will fetch their own copies of files. This optimization may be + disabled in order to use Spark local directories that reside on NFS filesystems (see + SPARK-6313 for more details).
      spark.files.fetchTimeout60 - Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver, in seconds. - spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues.
      spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with + previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since + data may need to be rewritten to pre-existing output directories during checkpoint recovery.
      spark.storage.memoryFraction
      spark.storage.memoryMapThreshold2097152 + Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + This prevents Spark from memory mapping very small blocks. In general, memory + mapping has high overhead for blocks close to or below the page size of the operating system. +
      spark.storage.unrollFraction 0.2
      spark.storage.memoryMapThreshold2097152 - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. - This prevents Spark from memory mapping very small blocks. In general, memory - mapping has high overhead for blocks close to or below the page size of the operating system. -
      spark.tachyonStore.url tachyon://localhost:19998
      spark.cleaner.ttl(infinite) - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. -
      spark.hadoop.validateOutputSpecstrueIf set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with - previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. - This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since - data may need to be rewritten to pre-existing output directories during checkpoint recovery.
      spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This - option should be enabled to work around Configuration thread-safety issues (see - SPARK-2546 for more details). - This is disabled by default in order to avoid unexpected performance regressions for jobs that - are not affected by these issues.
      spark.executor.heartbeatInterval10000Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let - the driver know that the executor is still alive and update it with metrics for in-progress - tasks.
      #### Networking - - + + - - + + - - + + - - + + - - + + - - + + - + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + + Number of times to retry before an RPC task gives up. + An RPC task will run at most times of this number. - - + + - - + + - - + + + Duration for an RPC remote endpoint lookup operation to wait before timing out.
      Property NameDefaultMeaning
      spark.driver.host(local hostname)spark.akka.failure-detector.threshold300.0 - Hostname or IP address for the driver to listen on. - This is used for communicating with the executors and the standalone Master. + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be + enabled again, if you plan to use this feature (Not recommended). This maps to akka's + `akka.remote.transport-failure-detector.threshold`. Tune this in combination of + `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
      spark.driver.port(random)spark.akka.frameSize10 - Port for the driver to listen on. - This is used for communicating with the executors and the standalone Master. + Maximum message size to allow in "control plane" communication (for serialized tasks and task + results), in MB. Increase this if your tasks need to send back large results to the driver + (e.g. using collect() on a large dataset).
      spark.fileserver.port(random)spark.akka.heartbeat.interval1000s - Port for the driver's HTTP file server to listen on. + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + with those.
      spark.broadcast.port(random)spark.akka.heartbeat.pauses6000s - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. + This is set to a larger value to disable the transport failure detector that comes built in to Akka. + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune + this along with `spark.akka.heartbeat.interval` if you need to.
      spark.replClassServer.port(random)spark.akka.threads4 - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. + Number of actor threads to use for communication. Can be useful to increase on large clusters + when the driver has a lot of CPU cores.
      spark.blockManager.port(random)spark.akka.timeout100s - Port for all block managers to listen on. These exist on both the driver and the executors. + Communication timeout between Spark nodes.
      spark.executor.portspark.blockManager.port (random) - Port for the executor to listen on. This is used for communicating with the driver. + Port for all block managers to listen on. These exist on both the driver and the executors.
      spark.port.maxRetries16spark.broadcast.port(random) - Default maximum number of retries when binding to a port before giving up. + Port for the driver's HTTP broadcast server to listen on. + This is not relevant for torrent broadcast.
      spark.akka.frameSize10spark.driver.host(local hostname) - Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the driver - (e.g. using collect() on a large dataset). + Hostname or IP address for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
      spark.akka.threads4spark.driver.port(random) - Number of actor threads to use for communication. Can be useful to increase on large clusters - when the driver has a lot of CPU cores. + Port for the driver to listen on. + This is used for communicating with the executors and the standalone Master.
      spark.akka.timeout100spark.executor.port(random) - Communication timeout between Spark nodes, in seconds. + Port for the executor to listen on. This is used for communicating with the driver.
      spark.network.timeout120spark.fileserver.port(random) - Default timeout for all network interactions, in seconds. This config will be used in - place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or - spark.shuffle.io.connectionTimeout, if they are not configured. + Port for the driver's HTTP file server to listen on.
      spark.akka.heartbeat.pauses6000spark.network.timeout120s - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause - in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in - combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold` - if you need to. + Default timeout for all network interactions. This config will be used in place of + spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.storage.blockManagerSlaveTimeoutMs, + spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or + spark.rpc.lookupTimeout if they are not configured.
      spark.akka.failure-detector.threshold300.0spark.port.maxRetries16 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). This maps to akka's - `akka.remote.transport-failure-detector.threshold`. Tune this in combination of - `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. + Default maximum number of retries when binding to a port before giving up.
      spark.akka.heartbeat.interval1000spark.replClassServer.port(random) - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be - enabled again, if you plan to use this feature (Not recommended). A larger interval value in - seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for - akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and - `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using - failure detector can be, a sensistive failure detector can help evict rogue executors really - quick. However this is usually not the case as gc pauses and network lags are expected in a - real Spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats - between nodes leading to flooding the network with those. + Port for the driver's HTTP class server to listen on. + This is only relevant for the Spark shell.
      spark.shuffle.io.preferDirectBufstruespark.rpc.numRetries3 - (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache - block transfer. For environments where off-heap memory is tightly limited, users may wish to - turn this off to force all allocations from Netty to be on-heap.
      spark.shuffle.io.numConnectionsPerPeer1spark.rpc.retry.wait3s - (Netty only) Connections between hosts are reused in order to reduce connection buildup for - large clusters. For clusters with many hard disks and few hosts, this may result in insufficient - concurrency to saturate all disks, and so users may consider increasing this value. + Duration for an RPC ask operation to wait before retrying.
      spark.shuffle.io.maxRetries3spark.rpc.askTimeout120s - (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is - set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC - pauses or transient network connectivity issues. + Duration for an RPC ask operation to wait before timing out.
      spark.shuffle.io.retryWait5spark.rpc.lookupTimeout120s - (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds.
      @@ -911,31 +1017,6 @@ Apart from these, the following properties are also available, and may be useful #### Scheduling - - - - - - - - - - - - - - - @@ -949,50 +1030,19 @@ Apart from these, the following properties are also available, and may be useful - - - - - - + - - - - - - - - - - - - - - - - + - + - + @@ -1024,16 +1074,15 @@ Apart from these, the following properties are also available, and may be useful - - + + - + - + - - + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.task.cpus1 - Number of cores to allocate for each task. -
      spark.task.maxFailures4 - Number of individual task failures before giving up on the job. - Should be greater than or equal to 1. Number of allowed retries = this value - 1. -
      spark.scheduler.modeFIFO - The scheduling mode between - jobs submitted to the same SparkContext. Can be set to FAIR - to use fair sharing instead of queueing jobs one after another. Useful for - multi-user services. -
      spark.cores.max (not set)
      spark.mesos.coarsefalse - If set to "true", runs over Mesos clusters in - "coarse-grained" sharing mode, - where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per - Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use - for the whole duration of the Spark job. -
      spark.speculationspark.localExecution.enabled false - If set to "true", performs speculative execution of tasks. This means if one or more tasks are - running slowly in a stage, they will be re-launched. -
      spark.speculation.interval100 - How often Spark will check for tasks to speculate, in milliseconds. -
      spark.speculation.quantile0.75 - Percentage of tasks which must be complete before speculation is enabled for a particular stage. -
      spark.speculation.multiplier1.5 - How many times slower a task is than the median to be considered for speculation. + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver.
      spark.locality.wait30003s - Number of milliseconds to wait to launch a data-local task before giving up and launching it + How long to wait to launch a data-local task before giving up and launching it on a less-local node. The same wait will be used to step through multiple locality levels (process-local, node-local, rack-local and then any). It is also possible to customize the waiting time for each level by setting spark.locality.wait.node, etc. @@ -1001,19 +1051,19 @@ Apart from these, the following properties are also available, and may be useful
      spark.locality.wait.processspark.locality.wait.node spark.locality.wait - Customize the locality wait for process locality. This affects tasks that attempt to access - cached data in a particular executor process. + Customize the locality wait for node locality. For example, you can set this to 0 to skip + node locality and search immediately for rack locality (if your cluster has rack information).
      spark.locality.wait.nodespark.locality.wait.process spark.locality.wait - Customize the locality wait for node locality. For example, you can set this to 0 to skip - node locality and search immediately for rack locality (if your cluster has rack information). + Customize the locality wait for process locality. This affects tasks that attempt to access + cached data in a particular executor process.
      spark.scheduler.revive.interval1000spark.scheduler.maxRegisteredResourcesWaitingTime30s - The interval length for the scheduler to revive the worker resource offers to run tasks - (in milliseconds). + Maximum amount of time to wait for resources to register before scheduling begins.
      spark.scheduler.minRegisteredResourcesRatio0.0 for Mesos and Standalone mode, 0.8 for YARN0.8 for YARN mode; 0.0 otherwise The minimum ratio of registered resources (registered resources / total expected resources) (resources are executors in yarn mode, CPU cores in standalone mode) @@ -1044,25 +1093,69 @@ Apart from these, the following properties are also available, and may be useful
      spark.scheduler.maxRegisteredResourcesWaitingTime30000spark.scheduler.modeFIFO - Maximum amount of time to wait for resources to register before scheduling begins - (in milliseconds). + The scheduling mode between + jobs submitted to the same SparkContext. Can be set to FAIR + to use fair sharing instead of queueing jobs one after another. Useful for + multi-user services.
      spark.localExecution.enabledspark.scheduler.revive.interval1s + The interval length for the scheduler to revive the worker resource offers to run tasks. +
      spark.speculation false - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. + If set to "true", performs speculative execution of tasks. This means if one or more tasks are + running slowly in a stage, they will be re-launched. +
      spark.speculation.interval100ms + How often Spark will check for tasks to speculate. +
      spark.speculation.multiplier1.5 + How many times slower a task is than the median to be considered for speculation. +
      spark.speculation.quantile0.75 + Percentage of tasks which must be complete before speculation is enabled for a particular stage. +
      spark.task.cpus1 + Number of cores to allocate for each task. +
      spark.task.maxFailures4 + Number of individual task failures before giving up on the job. + Should be greater than or equal to 1. Number of allowed retries = this value - 1.
      -#### Dynamic allocation +#### Dynamic Allocation @@ -1074,32 +1167,49 @@ Apart from these, the following properties are also available, and may be useful available on YARN mode. For more detail, see the description here.

      - This requires the following configurations to be set: + This requires spark.shuffle.service.enabled to be set. + The following configurations are also relevant: spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and - spark.shuffle.service.enabled + spark.dynamicAllocation.initialExecutors + + + + + + + - - + + + + + + - + @@ -1112,20 +1222,30 @@ Apart from these, the following properties are also available, and may be useful description. - - - - -
      Property NameDefaultMeaning
      spark.dynamicAllocation.executorIdleTimeout600s + If dynamic allocation is enabled and an executor has been idle for more than this duration, + the executor will be removed. For more detail, see this + description.
      spark.dynamicAllocation.initialExecutors spark.dynamicAllocation.minExecutors(none) - Lower bound for the number of executors if dynamic allocation is enabled (required). + Initial number of executors to run if dynamic allocation is enabled.
      spark.dynamicAllocation.maxExecutors(none)Integer.MAX_VALUE + Upper bound for the number of executors if dynamic allocation is enabled. +
      spark.dynamicAllocation.minExecutors0 - Upper bound for the number of executors if dynamic allocation is enabled (required). + Lower bound for the number of executors if dynamic allocation is enabled.
      spark.dynamicAllocation.schedulerBacklogTimeout605s If dynamic allocation is enabled and there have been pending tasks backlogged for more than - this duration (in seconds), new executors will be requested. For more detail, see this + this duration, new executors will be requested. For more detail, see this description.
      spark.dynamicAllocation.executorIdleTimeout600 - If dynamic allocation is enabled and an executor has been idle for more than this duration - (in seconds), the executor will be removed. For more detail, see this - description. -
      #### Security + + + + + + + + + + @@ -1142,21 +1262,29 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + + + + - + - - + + @@ -1173,16 +1301,6 @@ Apart from these, the following properties are also available, and may be useful -Dspark.com.test.filter1.params='param1=foo,param2=testing' - - - - - @@ -1191,33 +1309,96 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. - - - - - - - - - -
      Property NameDefaultMeaning
      spark.acls.enablefalse + Whether Spark acls should are enabled. If enabled, this checks to see if the user has + access permissions to view or modify the job. Note this requires the user to be known, + so if the user comes across as null no checks are done. Filters can be used with the UI + to authenticate and set the user. +
      spark.admin.aclsEmpty + Comma separated list of users/administrators that have view and modify access to all Spark jobs. + This can be used if you run on a shared cluster and have a set of administrators or devs who + help debug when things work. +
      spark.authenticate false
      spark.core.connection.ack.wait.timeout60s + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. +
      spark.core.connection.auth.wait.timeout3030s - Number of seconds for the connection to wait for authentication to occur before timing + How long for the connection to wait for authentication to occur before timing out and giving up.
      spark.core.connection.ack.wait.timeout60spark.modify.aclsEmpty - Number of seconds for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. + Comma separated list of users that have modify access to the Spark job. By default only the + user that started the Spark job has access to modify it (kill it for example).
      spark.acls.enablefalse - Whether Spark acls should are enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
      spark.ui.view.acls Empty
      spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). -
      spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. -
      +#### Encryption + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.ssl.enabledfalse +

      Whether to enable SSL connections on all supported protocols.

      + +

      All the SSL settings like spark.ssl.xxx where xxx is a + particular configuration property, denote the global configuration for all the supported + protocols. In order to override the global configuration for the particular protocol, + the properties must be overwritten in the protocol-specific namespace.

      + +

      Use spark.ssl.YYY.XXX settings to overwrite the global configuration for + particular protocol denoted by YYY. Currently YYY can be + either akka for Akka based connections or fs for broadcast and + file server.

      +
      spark.ssl.enabledAlgorithmsEmpty + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + The reference list of protocols one can find on + this + page. +
      spark.ssl.keyPasswordNone + A password to the private key in key-store. +
      spark.ssl.keyStoreNone + A path to a key-store file. The path can be absolute or relative to the directory where + the component is started in. +
      spark.ssl.keyStorePasswordNone + A password to the key-store. +
      spark.ssl.protocolNone + A protocol name. The protocol must be supported by JVM. The reference list of protocols + one can find on this + page. +
      spark.ssl.trustStoreNone + A path to a trust-store file. The path can be absolute or relative to the directory + where the component is started in. +
      spark.ssl.trustStorePasswordNone + A password to the trust-store. +
      + + #### Spark Streaming - + - + + + + + +
      Property NameDefaultMeaning
      spark.streaming.blockInterval200200ms - Interval (milliseconds) at which data received by Spark Streaming receivers is chunked + Interval at which data received by Spark Streaming receivers is chunked into blocks of data before storing them in Spark. Minimum recommended - 50 ms. See the performance tuning section in the Spark Streaming programing guide for more details. @@ -1225,9 +1406,9 @@ Apart from these, the following properties are also available, and may be useful
      spark.streaming.receiver.maxRateinfinitenot set - Maximum number records per second at which each receiver will receive data. + Maximum rate (number of records per second) at which each receiver will receive data. Effectively, each stream will consume at most this number of records per second. Setting this configuration to 0 or a negative number will put no limit on the rate. See the deployment guide @@ -1255,15 +1436,27 @@ Apart from these, the following properties are also available, and may be useful higher memory usage in Spark.
      spark.streaming.kafka.maxRatePerPartitionnot set + Maximum rate (number of records per second) at which data will be read from each Kafka + partition when using the new Kafka direct stream API. See the + Kafka Integration guide + for more details. +
      #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: - * [YARN](running-on-yarn.html#configuration) - * [Mesos](running-on-mesos.html) - * [Standalone Mode](spark-standalone.html#cluster-launch-scripts) +##### [YARN](running-on-yarn.html#configuration) + +##### [Mesos](running-on-mesos.html#configuration) + +##### [Standalone Mode](spark-standalone.html#cluster-launch-scripts) # Environment Variables diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index d50f445d7ecc..7f60f82b966f 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -5,7 +5,7 @@ title: Running Spark on EC2 The `spark-ec2` script, located in Spark's `ec2` directory, allows you to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark, Shark and HDFS on the cluster for you. This guide describes +sets up Spark and HDFS on the cluster for you. This guide describes how to use `spark-ec2` to launch clusters, how to run jobs on them, and how to shut them down. It assumes you've already signed up for an EC2 account on the [Amazon Web Services site](http://aws.amazon.com/). @@ -52,7 +52,7 @@ identify machines belonging to each cluster in the Amazon EC2 Console. ```bash export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster +./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster ``` - After everything launches, check that the cluster scheduler is up and sees diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e298c51f8a5b..3f10cb2dc3d2 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: GraphX Programming Guide +displayTitle: GraphX Programming Guide +title: GraphX +description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -536,7 +538,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts, ## Neighborhood Aggregation -A key step in may graph analytics tasks is aggregating information about the neighborhood of each +A key step in many graph analytics tasks is aggregating information about the neighborhood of each vertex. For example, we might want to know the number of followers each user has or the average age of the the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and @@ -632,7 +634,7 @@ avgAgeOfOlderFollowers.collect.foreach(println(_)) ### Map Reduce Triplets Transition Guide (Legacy) -In earlier versions of GraphX we neighborhood aggregation was accomplished using the +In earlier versions of GraphX neighborhood aggregation was accomplished using the [`mapReduceTriplets`][Graph.mapReduceTriplets] operator: {% highlight scala %} @@ -661,7 +663,7 @@ val graph: Graph[Int, Float] = ... def msgFun(triplet: Triplet[Int, Float]): Iterator[(Int, String)] = { Iterator((triplet.dstId, "Hi")) } -def reduceFun(a: Int, b: Int): Int = a + b +def reduceFun(a: String, b: String): String = a + " " + b val result = graph.mapReduceTriplets[String](msgFun, reduceFun) {% endhighlight %} @@ -672,7 +674,7 @@ val graph: Graph[Int, Float] = ... def msgFun(triplet: EdgeContext[Int, Float, String]) { triplet.sendToDst("Hi") } -def reduceFun(a: Int, b: Int): Int = a + b +def reduceFun(a: String, b: String): String = a + " " + b val result = graph.aggregateMessages[String](msgFun, reduceFun) {% endhighlight %} @@ -680,8 +682,8 @@ val result = graph.aggregateMessages[String](msgFun, reduceFun) ### Computing Degree Information A common aggregation task is computing the degree of each vertex: the number of edges adjacent to -each vertex. In the context of directed graphs it often necessary to know the in-degree, out- -degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a +each vertex. In the context of directed graphs it is often necessary to know the in-degree, +out-degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a collection of operators to compute the degrees of each vertex. For example in the following we compute the max in, out, and total degrees: @@ -897,6 +899,8 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { // Transform the values without changing the ids (preserves the internal index) def mapValues[VD2](map: VD => VD2): VertexRDD[VD2] def mapValues[VD2](map: (VertexId, VD) => VD2): VertexRDD[VD2] + // Show only vertices unique to this set based on their VertexId's + def minus(other: RDD[(VertexId, VD)]) // Remove vertices from this set that appear in the other set def diff(other: VertexRDD[VD]): VertexRDD[VD] // Join operators that take advantage of the internal indexing to accelerate joins (substantially) diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e75..317554c5f2a5 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904..1b90d7ec5a7a 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/docs/index.md b/docs/index.md index 171d6ddad62f..b5b016e34795 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Overview +displayTitle: Spark Overview +title: Overview +description: Apache Spark SPARK_VERSION_SHORT documentation homepage --- Apache Spark is a fast and general-purpose cluster computing system. @@ -72,7 +74,7 @@ options for deployment: in all supported languages (Scala, Java, Python) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model @@ -113,6 +115,8 @@ options for deployment: * [Spark Homepage](http://spark.apache.org) * [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) +* [Spark Community](http://spark.apache.org/community.html) resources, including local meetups +* [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), @@ -121,11 +125,3 @@ options for deployment: * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)) - -# Community - -To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.apache.org/mailing-lists.html). - -If you're in the San Francisco Bay Area, there's a regular [Spark meetup](http://www.meetup.com/spark-users/) every few weeks. Come by to meet the developers and other users. - -Finally, if you'd like to contribute code to Spark, read [how to contribute](contributing-to-spark.html). diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a5425eb3557b..963e88a3e1d8 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -14,8 +14,7 @@ runs an independent set of executor processes. The cluster managers that Spark r facilities for [scheduling across applications](#scheduling-across-applications). Second, _within_ each Spark application, multiple "jobs" (Spark actions) may be running concurrently if they were submitted by different threads. This is common if your application is serving requests -over the network; for example, the [Shark](http://shark.cs.berkeley.edu) server works this way. Spark -includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext. +over the network. Spark includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext. # Scheduling Across Applications @@ -52,8 +51,7 @@ an application to gain back cores on one node when it has work to do. To use thi Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying -the same RDDs. For example, the [Shark](http://shark.cs.berkeley.edu) JDBC server works this way for SQL -queries. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will +the same RDDs. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will provide another approach to share RDDs. ## Dynamic Resource Allocation @@ -77,11 +75,10 @@ scheduling while sharing cluster resources efficiently. ### Configuration and Setup All configurations used by this feature live under the `spark.dynamicAllocation.*` namespace. -To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true` and -provide lower and upper bounds for the number of executors through -`spark.dynamicAllocation.minExecutors` and `spark.dynamicAllocation.maxExecutors`. Other relevant -configurations are described on the [configurations page](configuration.html#dynamic-allocation) -and in the subsequent sections in detail. +To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true`. +Other relevant configurations are described on the +[configurations page](configuration.html#dynamic-allocation) and in the subsequent sections in +detail. Additionally, your application must use an external shuffle service. The purpose of the service is to preserve the shuffle files written by executors so the executors can be safely removed (more diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be178d7689fd..771a07183e26 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -23,13 +23,13 @@ to `spark.ml`. Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL as a dataset which can hold a variety of data types. +* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `SchemaRDD` into another `SchemaRDD`. +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `SchemaRDD` to produce a `Transformer`. +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. @@ -39,20 +39,20 @@ E.g., a learning algorithm is an `Estimator` which trains on a dataset and produ ## ML Dataset Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. -`SchemaRDD` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `SchemaRDD` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. -A `SchemaRDD` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. -Columns in a `SchemaRDD` are named. The code examples below use names such as "text," "features," and "label." +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." ## ML Algorithms ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `SchemaRDD` into another, generally by appending one or more columns. +A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. For example: * A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. @@ -60,7 +60,7 @@ For example: ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `SchemaRDD` and produces a `Transformer`. +An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. ### Properties of ML Algorithms @@ -101,7 +101,7 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). -The bottom row represents data flowing through the pipeline, where cylinders indicate `SchemaRDD`s. +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. @@ -130,7 +130,7 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `SchemaRDD`. +*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. ## Parameters @@ -171,12 +171,12 @@ import org.apache.spark.sql.{Row, SQLContext} val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. -val training = sparkContext.parallelize(Seq( +// into DataFrames, where it uses the case class metadata to infer the schema. +val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -192,7 +192,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training) +val model1 = lr.fit(training.toDF) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -203,33 +203,35 @@ println("Model 1 was fit using parameters: " + model1.fittingParamMap) // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params. +paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. -val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training, paramMapCombined) +val model2 = lr.fit(training.toDF, paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) -// Prepare test documents. -val test = sparkContext.parallelize(Seq( +// Prepare test data. +val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) -// Make predictions on test documents using the Transformer.transform() method. +// Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. -model2.transform(test) - .select('features, 'label, 'probability, 'prediction) +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test.toDF) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println("($features, $label) -> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
    @@ -244,23 +246,23 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into SchemaRDDs, where it uses the case class metadata to infer the schema. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -281,13 +283,13 @@ System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter(), 20); // Specify 1 Param. +paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam(), 0.1); +paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name. +paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -300,19 +302,19 @@ List localTest = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'probability' column instead of the usual 'score' -// column since we renamed the lr.scoreCol parameter previously. -model2.transform(test).registerAsTable("results"); -JavaSchemaRDD results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); -for (Row r: results.collect()) { +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %}
    @@ -330,6 +332,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} // Labeled and unlabeled instance types. @@ -337,14 +340,14 @@ import org.apache.spark.sql.{Row, SQLContext} case class LabeledDocument(id: Long, text: String, label: Double) case class Document(id: Long, text: String) -// Set up contexts. Import implicit conversions to SchemaRDD from sqlContext. +// Set up contexts. Import implicit conversions to DataFrame from sqlContext. val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. -val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -365,30 +368,32 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training) +val model = pipeline.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. -model.transform(test) - .select('id, 'text, 'score, 'prediction) +model.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println("($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
    {% highlight java %} -import java.io.Serializable; import java.util.List; import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -396,45 +401,44 @@ import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; -import org.apache.spark.SparkConf; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. public class Document implements Serializable { - private Long id; + private long id; private String text; - public Document(Long id, String text) { + public Document(long id, String text) { this.id = id; this.text = text; } - public Long getId() { return this.id; } - public void setId(Long id) { this.id = id; } + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } public String getText() { return this.text; } public void setText(String text) { this.text = text; } } public class LabeledDocument extends Document implements Serializable { - private Double label; + private double label; - public LabeledDocument(Long id, String text, Double label) { + public LabeledDocument(long id, String text, double label) { super(id, text); this.label = label; } - public Double getLabel() { return this.label; } - public void setLabel(Double label) { this.label = label; } + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } } // Set up contexts. SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -442,8 +446,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -468,16 +471,62 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = - jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. -model.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleTextClassificationPipeline") +sqlContext = SQLContext(sc) + +# Prepare training documents, which are labeled. +LabeledDocument = Row("id", "text", "label") +training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + +# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled. +Document = Row("id", "text") +test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print row + +sc.stop() {% endhighlight %}
    @@ -508,21 +557,26 @@ However, it is also a well-established method for choosing parameters which is m
    {% highlight scala %} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from case classes. +case class LabeledDocument(id: Long, text: String, label: Double) +case class Document(id: Long, text: String) + val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) -import sqlContext._ +import sqlContext.implicits._ // Prepare training documents, which are labeled. -val training = sparkContext.parallelize(Seq( +val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -565,24 +619,24 @@ crossval.setEstimatorParamMaps(paramGrid) crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training) -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). -val lrModel = cvModel.bestModel +val cvModel = crossval.fit(training.toDF) // Prepare test documents, which are unlabeled. -val test = sparkContext.parallelize(Seq( +val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) +cvModel.transform(test.toDF) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } + +sc.stop() {% endhighlight %}
    @@ -592,7 +646,6 @@ import java.util.List; import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Model; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -603,13 +656,43 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); -JavaSQLContext jsql = new JavaSQLContext(jsc); +SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -625,8 +708,7 @@ List localTraining = Lists.newArrayList( new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); -JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); +DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -660,8 +742,6 @@ crossval.setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. CrossValidatorModel cvModel = crossval.fit(training); -// Get the best LogisticRegression model (with the best set of parameters from paramGrid). -Model lrModel = cvModel.bestModel(); // Prepare test documents, which are unlabeled. List localTest = Lists.newArrayList( @@ -669,15 +749,16 @@ List localTest = Lists.newArrayList( new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); -JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); +DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test).registerAsTable("prediction"); -JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); -for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + +jsc.stop(); {% endhighlight %} @@ -686,6 +767,21 @@ for (Row r: predictions.collect()) { # Dependencies Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info. +Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. + +# Migration Guide + +## From 1.2 to 1.3 + +The main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 719cc95767b0..8e91d62f4a90 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -17,13 +17,13 @@ the supported algorithms for each type of problem. - Binary Classificationlinear SVMs, logistic regression, decision trees, naive Bayes + Binary Classificationlinear SVMs, logistic regression, decision trees, random forests, gradient-boosted trees, naive Bayes - Multiclass Classificationdecision trees, naive Bayes + Multiclass Classificationdecision trees, random forests, naive Bayes - Regressionlinear least squares, Lasso, ridge regression, decision trees + Regressionlinear least squares, Lasso, ridge regression, decision trees, random forests, gradient-boosted trees, isotonic regression @@ -34,4 +34,8 @@ More details for these methods can be found here: * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) * [Decision trees](mllib-decision-tree.html) +* [Ensembles of decision trees](mllib-ensembles.html) + * [random forests](mllib-ensembles.html#random-forests) + * [gradient-boosted trees](mllib-ensembles.html#gradient-boosted-trees-gbts) * [Naive Bayes](mllib-naive-bayes.html) +* [Isotonic regression](mllib-isotonic-regression.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index c696ae9c8e8c..f5aa15b7d9b7 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -4,25 +4,25 @@ title: Clustering - MLlib displayTitle: MLlib - Clustering --- -* Table of contents -{:toc} - - -## Clustering - Clustering is an unsupervised learning problem whereby we aim to group subsets of entities with one another based on some notion of similarity. Clustering is often used for exploratory analysis and/or as a component of a hierarchical supervised learning pipeline (in which distinct classifiers or regression -models are trained for each cluster). +models are trained for each cluster). + +MLlib supports the following models: -MLlib supports -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) clustering, one of -the most commonly used clustering algorithms that clusters the data points into +* Table of contents +{:toc} + +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in MLlib has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -32,9 +32,9 @@ initialization via k-means\|\|. guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. -* *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *epsilon* determines the distance threshold within which we consider k-means to have converged. -### Examples +**Examples**
    @@ -148,41 +148,378 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -Quick Start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +## Gaussian mixture + +A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) +represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, +each with its own probability. The MLlib implementation uses the +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) + algorithm to induce the maximum-likelihood model given a set of samples. The implementation +has the following parameters: + +* *k* is the number of desired clusters. +* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved. +* *maxIterations* is the maximum number of iterations to perform without reaching convergence. +* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data. + +**Examples** + +
    +
    +In the following example after loading and parsing data, we use a +[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.clustering.GaussianMixtureModel +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/gmm_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using GaussianMixture +val gmm = new GaussianMixture().setK(2).run(parsedData) + +// Save and load model +gmm.save(sc, "myGMMModel") +val sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + +// output parameters of max-likelihood model +for (i <- 0 until gmm.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) +} + +{% endhighlight %} +
    + +
    +All of MLlib's methods use Java-friendly types, so you can import and call them there the same +way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the +Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by +calling `.rdd()` on your `JavaRDD` object. A self-contained application example +that is equivalent to the provided example in Scala is given below: + +{% highlight java %} +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.GaussianMixture; +import org.apache.spark.mllib.clustering.GaussianMixtureModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class GaussianMixtureExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("GaussianMixture Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse data + String path = "data/mllib/gmm_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + parsedData.cache(); + + // Cluster the data into two classes using GaussianMixture + GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + + // Save and load GaussianMixtureModel + gmm.save(sc, "myGMMModel") + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + // Output the parameters of the mixture model + for(int j=0; j + +
    +In the following example after loading and parsing data, we use a +[GaussianMixture](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed +to the algorithm. We then output the parameters of the mixture model. + +{% highlight python %} +from pyspark.mllib.clustering import GaussianMixture +from numpy import array + +# Load and parse the data +data = sc.textFile("data/mllib/gmm_data.txt") +parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) + +# Build the model (cluster the data) +gmm = GaussianMixture.train(parsedData, 2) + +# output parameters of model +for i in range(2): + print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, + "sigma = ", gmm.gaussians[i].sigma.toArray()) + +{% endhighlight %} +
    + +
    + +## Power iteration clustering (PIC) -## Streaming clustering +Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a +graph given pairwise similarties as edge properties, +described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). +It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via +[power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. +MLlib includes an implementation of PIC using GraphX as its backend. +It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. +The similarities must be nonnegative. +PIC assumes that the similarity measure is symmetric. +A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. +If a pair is missing from input, their similarity is treated as zero. +MLlib's PIC implementation takes the following (hyper-)parameters: -When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, -with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm -uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +* `k`: number of clusters +* `maxIterations`: maximum number of power iterations +* `initializationMode`: initialization model. This can be either "random", which is the default, + to use a random vector as vertex properties, or "degree" to use normalized sum similarities. + +**Examples** + +In the following, we show code snippets to demonstrate how to use PIC in MLlib. + +
    +
    + +[`PowerIterationClustering`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel), +which contains the computed clustering assignments. + +{% highlight scala %} +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.mllib.linalg.Vectors + +val similarities: RDD[(Long, Long, Double)] = ... + +val pic = new PowerIteartionClustering() + .setK(3) + .setMaxIterations(20) +val model = pic.run(similarities) + +model.assignments.foreach { a => + println(s"${a.id} -> ${a.cluster}") +} +{% endhighlight %} + +A full example that produces the experiment described in the PIC paper can be found under +[`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala). + +
    + +
    + +[`PowerIterationClustering`](api/java/org/apache/spark/mllib/clustering/PowerIterationClustering.html) +implements the PIC algorithm. +It takes an `JavaRDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) +which contains the computed clustering assignments. + +{% highlight java %} +import scala.Tuple2; +import scala.Tuple3; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +JavaRDD> similarities = ... + +PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); +PowerIterationClusteringModel model = pic.run(similarities); + +for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); +} +{% endhighlight %} +
    + +
    + +## Latent Dirichlet allocation (LDA) + +[Latent Dirichlet allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) +is a topic model which infers topics from a collection of text documents. +LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. +* Rather than estimating a clustering using a traditional distance, LDA uses a function based + on a statistical model of how text documents are generated. + +LDA takes in a collection of documents as vectors of word counts. +It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function. After fitting on the documents, LDA provides: + +* Topics: Inferred topics, each of which is a probability distribution over terms (words). +* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. + +LDA takes the following parameters: + +* `k`: Number of topics (i.e., cluster centers) +* `maxIterations`: Limit on the number of iterations of EM used for learning +* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. +* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. +* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. + +*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet +support prediction on new documents, and it does not have a Python API. These will be added in the future. + +**Examples** + +In the following example, we load word count vectors representing a corpus of documents. +We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) +to infer three topics from the documents. The number of desired clusters is passed +to the algorithm. We then output the topics, represented as probability distributions over words. + +
    +
    + +{% highlight scala %} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/sample_lda_data.txt") +val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) +// Index documents with unique IDs +val corpus = parsedData.zipWithIndex.map(_.swap).cache() + +// Cluster the documents into three topics using LDA +val ldaModel = new LDA().setK(3).run(corpus) + +// Output topics. Each is a distribution over words (matching word count vectors) +println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") +val topics = ldaModel.topicsMatrix +for (topic <- Range(0, 3)) { + print("Topic " + topic + ":") + for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + println() +} +{% endhighlight %} +
    + +
    +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + } +} +{% endhighlight %} +
    + +
    + +## Streaming k-means + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: `\begin{equation} c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} \end{equation}` `\begin{equation} - n_{t+1} = n_t + m_t + n_{t+1} = n_t + m_t \end{equation}` -Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned -to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` -is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` -can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; -with `$\alpha$=0` only the most recent data will be used. This is analogous to an -exponentially-weighted moving average. +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. -The decay can be specified using a `halfLife` parameter, which determines the +The decay can be specified using a `halfLife` parameter, which determines the correct decay factor `a` such that, for data acquired at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. The unit of time can be specified either as `batches` or `points` and the update rule will be adjusted accordingly. -### Examples +**Examples** This example shows how to estimate clusters on streaming data. @@ -200,9 +537,9 @@ import org.apache.spark.mllib.clustering.StreamingKMeans {% endhighlight %} -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. {% highlight scala %} @@ -224,24 +561,24 @@ val model = new StreamingKMeans() {% endhighlight %} -Now register the streams for training and testing and start the job, printing +Now register the streams for training and testing and start the job, printing the predicted cluster assignments on new data points as they arrive. {% highlight scala %} model.trainOn(trainingData) -model.predictOnValues(testData).print() +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} -As you add new text files with data the cluster centers will update. Each training +As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point -should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier -(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change!
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 209496339229..76140282a2dd 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,6 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. {% highlight scala %} import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel import org.apache.spark.mllib.recommendation.Rating // Load and parse the data @@ -95,6 +96,10 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => err * err }.mean() println("Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from another source of information (e.g., it is inferred from @@ -181,6 +186,10 @@ public class CollaborativeFiltering { } ).rdd()).mean(); System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -192,12 +201,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,11 +213,15 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = MatrixFactorizationModel.load(sc, "myModelPath") {% endhighlight %} If the rating matrix is derived from other source of information (i.e., it is inferred from other @@ -217,7 +229,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %}
    diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 101dc2f8695f..4f2a2f71048f 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -78,13 +78,13 @@ MLlib recognizes the following types as dense vectors: and the following as sparse vectors: -* MLlib's [`SparseVector`](api/python/pyspark.mllib.linalg.SparseVector-class.html). +* MLlib's [`SparseVector`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseVector). * SciPy's [`csc_matrix`](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix) with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.linalg.Vectors-class.html) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. {% highlight python %} import numpy as np @@ -151,7 +151,7 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new
    A labeled point is represented by -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html). +[`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint). {% highlight python %} from pyspark.mllib.linalg import SparseVector @@ -211,7 +211,7 @@ JavaRDD examples =
    -[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.util.MLUtils-class.html) reads training +[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training examples stored in LIBSVM format. {% highlight python %} @@ -296,6 +296,70 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    +
    + ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index fc8e732251a3..c1d0f8a6b1cd 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Tree - MLlib -displayTitle: MLlib - Decision Tree +title: Decision Trees - MLlib +displayTitle: MLlib - Decision Trees --- * Table of contents @@ -54,8 +54,8 @@ impurity measure for regression (variance). Variance Regression - $\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$y_i$ is label for an instance, - $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$. + $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$. @@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    {% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -221,6 +222,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -279,13 +284,18 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -305,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -324,6 +338,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    {% highlight scala %} import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -350,6 +365,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression tree model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -414,13 +433,18 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -440,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression tree model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = DecisionTreeModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 23ede04b62d5..7521fb14a7bd 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -98,6 +98,7 @@ The test error is calculated to measure the algorithm accuracy.
    {% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -127,6 +128,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -188,12 +193,17 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -216,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes print('Test Error = ' + str(testErr)) print('Learned classification forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -235,6 +249,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    {% highlight scala %} import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -264,6 +279,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression forest model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -328,12 +347,17 @@ Double testMSE = }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression forest model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    + {% highlight python %} -from pyspark.mllib.tree import RandomForest +from pyspark.mllib.tree import RandomForest, RandomForestModel from pyspark.mllib.util import MLUtils # Load and parse the data file into an RDD of LabeledPoint. @@ -356,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = RandomForestModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -427,10 +455,19 @@ We omit some decision tree parameters since those are covered in the [decision t * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. +#### Validation while training -### Examples +Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while +training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDD's as arguments, the +first one being the training dataset and the second being the validation dataset. -GBTs currently have APIs in Scala and Java. Examples in both languages are shown below. +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error +decreases initially and later increases. There might be cases in which the validation error does not change monotonically, +and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration` +(which gives the error or loss per iteration) to tune the number of iterations. + +### Examples #### Classification @@ -446,6 +483,7 @@ The test error is calculated to measure the algorithm accuracy. {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -458,7 +496,7 @@ val (trainingData, testData) = (splits(0), splits(1)) // The defaultParams for Classification use LogLoss by default. val boostingStrategy = BoostingStrategy.defaultParams("Classification") boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClassesForClassification = 2 +boostingStrategy.treeStrategy.numClasses = 2 boostingStrategy.treeStrategy.maxDepth = 5 // Empty categoricalFeaturesInfo indicates all features are continuous. boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() @@ -473,6 +511,10 @@ val labelAndPreds = testData.map { point => val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() println("Test Error = " + testErr) println("Learned classification GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -534,6 +576,41 @@ Double testErr = }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
    + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) +print('Test Error = ' + str(testErr)) +print('Learned classification GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -554,6 +631,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate {% highlight scala %} import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLUtils // Load and parse the data file. @@ -580,6 +658,10 @@ val labelsAndPredictions = testData.map { point => val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() println("Test Mean Squared Error = " + testMSE) println("Learned regression GBT model:\n" + model.toDebugString) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %} @@ -647,6 +729,41 @@ Double testMSE = }) / data.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} + + +
    + +{% highlight python %} +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils + +# Load and parse the data file. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GradientBoostedTrees model. +# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. +# (b) Use more iterations in practice. +model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + +# Evaluate model on test instances and compute test error +predictions = model.predict(testData.map(lambda x: x.features)) +labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) +testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) +print('Test Mean Squared Error = ' + str(testMSE)) +print('Learned regression GBT model:') +print(model.toDebugString()) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") {% endhighlight %}
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 197bc77d506c..80842b27effd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -240,11 +240,11 @@ following parameters in the constructor: * `withMean` False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -* `withStd` True by default. Scales the data to unit variance. +* `withStd` True by default. Scales the data to unit standard deviation. We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then -return a model which can transform the input dataset into unit variance and/or zero mean features +return a model which can transform the input dataset into unit standard deviation and/or zero mean features depending how we configure the `StandardScaler`. This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) @@ -257,7 +257,7 @@ for that feature. ### Example The example below demonstrates how to load a dataset in libsvm format, and standardize the features -so that the new features have unit variance and/or zero mean. +so that the new features have unit standard deviation and/or zero mean.
    @@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val scaler1 = new StandardScaler().fit(data.map(x => x.features)) val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) +// scaler3 is an identical model to scaler2, and will produce identical transformations +val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean) // data1 will be unit variance. val data1 = data.map(x => (x.label, scaler1.transform(x.features))) @@ -294,6 +296,9 @@ features = data.map(lambda x: x.features) scaler1 = StandardScaler().fit(features) scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) +# scaler3 is an identical model to scaler2, and will produce identical transformations +scaler3 = StandardScalerModel(scaler2.std, scaler2.mean) + # data1 will be unit variance. data1 = label.zip(scaler1.transform(features)) @@ -370,3 +375,105 @@ data2 = labels.zip(normalizer2.transform(features)) {% endhighlight %}
    + +## Feature selection +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. + +### ChiSqSelector +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. + +#### Model Fitting + +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the +following parameters in the constructor: + +* `numTopFeatures` number of top features that the selector will select (filter). + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in +`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then +return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +an `RDD[Vector]` to produce a reduced `RDD[Vector]`. + +Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). + +#### Example + +The following example shows the basic use of ChiSqSelector. + +
    +
    +{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load some data in libsvm format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) +} +// Create ChiSqSelector that will select 50 features +val selector = new ChiSqSelector(50) +// Create ChiSqSelector model (selecting features) +val transformer = selector.fit(discretizedData) +// Filter the top 50 features from each feature vector +val filteredData = discretizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) +} +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); +JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = lp.features().apply(i) / 16; + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + }); + +// Create ChiSqSelector that will select 50 features +ChiSqSelector selector = new ChiSqSelector(50); +// Create ChiSqSelector model (selecting features) +final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); +// Filter the top 50 features from each feature vector +JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } +); + +sc.stop(); +{% endhighlight %} +
    +
    + diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md new file mode 100644 index 000000000000..9fd9be0dd01b --- /dev/null +++ b/docs/mllib-frequent-pattern-mining.md @@ -0,0 +1,98 @@ +--- +layout: global +title: Frequent Pattern Mining - MLlib +displayTitle: MLlib - Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. +MLlib provides a parallel implementation of FP-growth, +a popular algorithm to mining frequent itemsets. + +## FP-growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In MLlib, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffices of transactions, +and hence more scalable than a single-machine implementation. +We refer users to the papers for more details. + +MLlib's FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `numPartitions`: the number of partitions used to distribute the work. + +**Examples** + +
    +
    + +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} + +val transactions: RDD[Array[String]] = ... + +val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) +val model = fpg.run(transactions) + +model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) +} +{% endhighlight %} + +
    + +
    + +[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +FP-growth algorithm. +It take an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +Calling `FPGrowth.run` with transactions returns an +[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +that stores the frequent itemsets with their frequencies. + +{% highlight java %} +import java.util.List; + +import com.google.common.base.Joiner; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +JavaRDD> transactions = ... + +FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); +FPGrowthModel model = fpg.run(transactions); + +for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); +} +{% endhighlight %} + +
    +
    diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 39c64d06926b..f8e879496c13 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Machine Learning Library (MLlib) Programming Guide +title: MLlib +displayTitle: Machine Learning Library (MLlib) Guide +description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, @@ -19,14 +21,21 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) * [Clustering](mllib-clustering.html) - * k-means + * [k-means](mllib-clustering.html#k-means) + * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) + * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) + * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) * [Feature extraction and transformation](mllib-feature-extraction.html) +* [Frequent pattern mining](mllib-frequent-pattern-mining.html) + * FP-growth * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) @@ -37,7 +46,7 @@ and the migration guide below will explain all changes between releases. # spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. It is currently an alpha component, and we would like to hear back from the community about how it fits real-world use cases and how it could be improved. @@ -52,149 +61,50 @@ See the **[spark.ml programming guide](ml-guide.html)** for more information on # Dependencies -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), -which depends on [netlib-java](https://github.com/fommil/netlib-java), -and [jblas](https://github.com/mikiobraun/jblas). -`netlib-java` and `jblas` depend on native Fortran routines. -You need to install the -[gfortran runtime library](https://github.com/mikiobraun/jblas/wiki/Missing-Libraries) -if it is not already present on your nodes. -MLlib will throw a linking error if it cannot detect these libraries automatically. -Due to license issues, we do not include `netlib-java`'s native libraries in MLlib's -dependency set under default settings. -If no native library is available at runtime, you will see a warning message. -To use native libraries from `netlib-java`, please build Spark with `-Pnetlib-lgpl` or -include `com.github.fommil.netlib:all:1.1.2` as a dependency of your project. -If you want to use optimized BLAS/LAPACK libraries such as -[OpenBLAS](http://www.openblas.net/), please link its shared libraries to -`/usr/lib/libblas.so.3` and `/usr/lib/liblapack.so.3`, respectively. -BLAS/LAPACK libraries on worker nodes should be built without multithreading. - -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. +MLlib uses the linear algebra package +[Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised +numerical processing. If natives are not available at runtime, you +will see a warning message and a pure JVM implementation will be used +instead. ---- - -# Migration Guide - -## From 1.1 to 1.2 - -The only API changes in MLlib v1.2 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.2: - -1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number -of classes. In MLlib v1.1, this argument was called `numClasses` in Python and -`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. -This `numClasses` parameter is specified either via -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Breaking change)* The API for -[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. -This should generally not affect user code, unless the user manually constructs decision trees -(instead of using the `trainClassifier` or `trainRegressor` methods). -The tree `Node` now includes more information, including the probability of the predicted label -(for classification). - -3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. - -Examples in the Spark distribution and examples in the -[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. - -## From 1.0 to 1.1 - -The only API changes in MLlib v1.1 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.1: - -1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match -the implementations of trees in -[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) -and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). -In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. -In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. -This depth is specified by the `maxDepth` parameter in -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` -methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -rather than using the old parameter class `Strategy`. These new training methods explicitly -separate classification and regression, and they replace specialized parameter types with -simple `String` types. +To learn more about the benefits and background of system optimised +natives, you may wish to watch Sam Halliday's ScalaX talk on +[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the -[Decision Trees Guide](mllib-decision-tree.html#examples). +Due to licensing issues with runtime proprietary binaries, we do not +include `netlib-java`'s native proxies by default. To configure +`netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with +`-Pnetlib-lgpl`) as a dependency of your project and read the +[netlib-java](https://github.com/fommil/netlib-java) documentation for +your platform's additional installation instructions. -## From 0.9 to 1.0 +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) +version 1.4 or newer. -In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few -breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. Details are described below. - -
    -
    - -We used to represent a feature vector by `Array[Double]`, which is replaced by -[`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) in v1.0. Algorithms that used -to accept `RDD[Array[Double]]` now take -`RDD[Vector]`. [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) -is now a wrapper of `(Double, Vector)` instead of `(Double, Array[Double])`. Converting -`Array[Double]` to `Vector` is straightforward: - -{% highlight scala %} -import org.apache.spark.mllib.linalg.{Vector, Vectors} - -val array: Array[Double] = ... // a double array -val vector: Vector = Vectors.dense(array) // a dense vector -{% endhighlight %} - -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors. - -*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. - -
    - -
    - -We used to represent a feature vector by `double[]`, which is replaced by -[`Vector`](api/java/index.html?org/apache/spark/mllib/linalg/Vector.html) in v1.0. Algorithms that used -to accept `RDD` now take -`RDD`. [`LabeledPoint`](api/java/index.html?org/apache/spark/mllib/regression/LabeledPoint.html) -is now a wrapper of `(double, Vector)` instead of `(double, double[])`. Converting `double[]` to -`Vector` is straightforward: - -{% highlight java %} -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; +--- -double[] array = ... // a double array -Vector vector = Vectors.dense(array); // a dense vector -{% endhighlight %} +# Migration Guide -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to -create sparse vectors. +For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -
    - -
    +## From 1.2 to 1.3 -We used to represent a labeled feature vector in a NumPy array, where the first entry corresponds to -the label and the rest are features. This representation is replaced by class -[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html), which takes both -dense and sparse feature vectors. +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. -{% highlight python %} -from pyspark.mllib.linalg import SparseVector -from pyspark.mllib.regression import LabeledPoint +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. -# Create a labeled point with a positive label and a dense feature vector. -pos = LabeledPoint(1.0, [1.0, 0.0, 3.0]) +## Previous Spark Versions -# Create a labeled point with a negative label and a sparse feature vector. -neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) -{% endhighlight %} -
    -
    +Earlier migration guides are archived [on this page](mllib-migration-guides.html). diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md new file mode 100644 index 000000000000..b521c2f27cd6 --- /dev/null +++ b/docs/mllib-isotonic-regression.md @@ -0,0 +1,155 @@ +--- +layout: global +title: Isotonic regression - MLlib +displayTitle: MLlib - Regression +--- + +## Isotonic regression +[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression) +belongs to the family of regression algorithms. Formally isotonic regression is a problem where +given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses +and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted +finding a function that minimises + +`\begin{equation} + f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 +\end{equation}` + +with respect to complete order subject to +`$x_1\le x_2\le ...\le x_n$` where `$w_i$` are positive weights. +The resulting function is called isotonic regression and it is unique. +It can be viewed as least squares problem under order restriction. +Essentially isotonic regression is a +[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) +best fitting the original data points. + +MLlib supports a +[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +which uses an approach to +[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +The training input is a RDD of tuples of three double values that represent +label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +optional parameter called $isotonic$ defaulting to true. +This argument specifies if the isotonic regression is +isotonic (monotonically increasing) or antitonic (monotonically decreasing). + +Training returns an IsotonicRegressionModel that can be used to predict +labels for both known and unknown features. The result of isotonic regression +is treated as piecewise linear function. The rules for prediction therefore are: + +* If the prediction input exactly matches a training feature + then associated prediction is returned. In case there are multiple predictions with the same + feature then one of them is returned. Which one is undefined + (same as java.util.Arrays.binarySearch). +* If the prediction input is lower or higher than all training features + then prediction with lowest or highest feature is returned respectively. + In case there are multiple predictions with the same feature + then the lowest or highest is returned respectively. +* If the prediction input falls between two training features then prediction is treated + as piecewise linear function and interpolated value is calculated from the + predictions of the two closest features. In case there are multiple values + with the same feature then the same rules as in previous point are used. + +### Examples + +
    +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight scala %} +import org.apache.spark.mllib.regression.IsotonicRegression + +val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) +} + +// Split data into training (60%) and test (40%) sets. +val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0) +val test = splits(1) + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +val model = new IsotonicRegression().setIsotonic(true).run(training) + +// Create tuples of predicted and real labels. +val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) +} + +// Calculate mean squared error between predicted and real labels. +val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() +println("Mean Squared Error = " + meanSquaredError) +{% endhighlight %} +
    + +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import scala.Tuple2; +import scala.Tuple3; + +JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } +); + +// Split data into training (60%) and test (40%) sets. +JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); +JavaRDD> training = splits[0]; +JavaRDD> test = splits[1]; + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + +// Create tuples of predicted and real labels. +JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } +); + +// Calculate mean squared error between predicted and real labels. +Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } +).rdd()).mean(); + +System.out.println("Mean Squared Error = " + meanSquaredError); +{% endhighlight %} +
    +
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 44b7f67c5773..2b2be4d9d027 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -17,7 +17,7 @@ displayTitle: MLlib - Linear Methods \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} +\newcommand{\id}{\mathbf{I}} \newcommand{\ind}{\mathbf{1}} \newcommand{\0}{\mathbf{0}} \newcommand{\unit}{\mathbf{e}} @@ -114,18 +114,26 @@ especially when the number of training examples is small. Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. -## Binary classification - -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) -aims to divide items into two categories: positive and negative. MLlib -supports two linear methods for binary classification: linear Support Vector -Machines (SVMs) and logistic regression. For both methods, MLlib supports -L1 and L2 regularized variants. The training data set is represented by an RDD -of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the -mathematical formulation in this guide, a training label $y$ is denoted as -either $+1$ (positive) or $-1$ (negative), which is convenient for the -formulation. *However*, the negative label is represented by $0$ in MLlib -instead of $-1$, to be consistent with multiclass labeling. +## Classification + +[Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into +categories. +The most common classification type is +[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +categories, usually named positive and negative. +If there are more than two categories, it is called +[multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). +MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +and logistic regression. +Linear SVMs supports only binary classification, while logistic regression supports both binary and +multiclass classification problems. +For both methods, MLlib supports L1 and L2 regularized variants. +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +where labels are class indices starting from zero: $0, 1, 2, \ldots$. +Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either +$+1$ (positive) or $-1$ (negative), which is convenient for the formulation. +*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -144,41 +152,7 @@ denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative otherwise. -### Logistic regression - -[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. -It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss -function in the formulation given by the logistic loss: -`\[ -L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). -\]` - -The logistic regression algorithm outputs a logistic regression model. Given a -new data point, denoted by $\x$, the model makes predictions by -applying the logistic function -`\[ -\mathrm{f}(z) = \frac{1}{1 + e^{-z}} -\]` -where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or -negative otherwise, though unlike linear SVMs, the raw output of the logistic regression -model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability -that $\x$ is positive). - -### Evaluation metrics - -MLlib supports common evaluation metrics for binary classification (not available in PySpark). -This -includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), -[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), -precision-recall curve, and -[area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -AUC is commonly used to compare the performance of various models while -precision/recall/F-measure can help determine the appropriate threshold to use -for prediction purposes. - -### Examples +**Examples**
    @@ -190,7 +164,7 @@ error. {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.SVMWithSGD +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors @@ -211,7 +185,7 @@ val model = SVMWithSGD.train(training, numIterations) // Clear the default threshold. model.clearThreshold() -// Compute raw scores on the test set. +// Compute raw scores on the test set. val scoreAndLabels = test.map { point => val score = model.predict(point.features) (score, point.label) @@ -222,6 +196,10 @@ val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() println("Area under ROC = " + auROC) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %} The `SVMWithSGD.train()` method by default performs L2 regularization with the @@ -243,8 +221,6 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. -
    @@ -280,11 +256,11 @@ public class SVMClassifier { JavaRDD training = data.sample(false, 0.6, 11L); training.cache(); JavaRDD test = data.subtract(training); - + // Run training algorithm to build the model. int numIterations = 100; final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - + // Clear the default threshold. model.clearThreshold(); @@ -297,13 +273,17 @@ public class SVMClassifier { } } ); - + // Get evaluation metrics. - BinaryClassificationMetrics metrics = + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); double auROC = metrics.areaUnderROC(); - + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -338,6 +318,8 @@ a dependency. The following example shows how to load a sample dataset, build Logistic Regression model, and make predictions with the resulting model to compute the training error. +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.classification import LogisticRegressionWithSGD from pyspark.mllib.regression import LabeledPoint @@ -362,7 +344,191 @@ print("Training Error = " + str(trainErr))
    -## Linear least squares, Lasso, and ridge regression +### Logistic regression + +[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a +binary response. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, +with the loss function in the formulation given by the logistic loss: +`\[ +L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). +\]` + +For binary classification problems, the algorithm outputs a binary logistic regression model. +Given a new data point, denoted by $\x$, the model makes predictions by +applying the logistic function +`\[ +\mathrm{f}(z) = \frac{1}{1 + e^{-z}} +\]` +where $z = \wv^T \x$. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). + +Binary logistic regression can be generalized into +[multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression) to +train and predict multiclass classification problems. +For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the +other $K - 1$ outcomes can be separately regressed against the pivot outcome. +In MLlib, the first class $0$ is chosen as the "pivot" class. +See Section 4.4 of +[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +references. +Here is an +[detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). + +For multiclass classification problems, the algorithm will output a multinomial logistic regression +model, which contains $K - 1$ binary logistic regression models regressed against the first class. +Given a new data points, $K - 1$ models will be run, and the class with largest probability will be +chosen as the predicted class. + +We implemented two algorithms to solve logistic regression: mini-batch gradient descent and L-BFGS. +We recommend L-BFGS over mini-batch gradient descent for faster convergence. + +**Examples** + +
    + +
    +The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + +// Split data into training (60%) and test (40%). +val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0).cache() +val test = splits(1) + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + +// Compute raw scores on the test set. +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Get evaluation metrics. +val metrics = new MulticlassMetrics(predictionAndLabels) +val precision = metrics.precision +println("Precision = " + precision) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LogisticRegressionModel.load(sc, "myModelPath") +{% endhighlight %} + +
    + +
    +The following code illustrates how to load a sample multiclass dataset, split it into train and +test, and use +[LogisticRegressionWithLBFGS](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) +to fit a logistic regression model. +Then the model is evaluated against the test dataset and saved to disk. + +{% highlight java %} +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class MultinomialLogisticRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); + SparkContext sc = new SparkContext(conf); + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "myModelPath"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); + } +} +{% endhighlight %} +
    + +
    +The following example shows how to load a sample dataset, build Logistic Regression model, +and make predictions with the resulting model to compute the training error. + +Note that the Python API does not yet support multiclass classification and model save/load but +will in the future. + +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.regression import LabeledPoint +from numpy import array + +# Load and parse the data +def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + +data = sc.textFile("data/mllib/sample_svm_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LogisticRegressionWithLBFGS.train(parsedData) + +# Evaluating the model on training data +labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) +trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) +print("Training Error = " + str(trainErr)) +{% endhighlight %} +
    +
    + +# Regression + +### Linear least squares, Lasso, and ridge regression Linear least squares is the most common formulation for regression problems. @@ -380,7 +546,7 @@ regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) u regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -### Examples +**Examples**
    @@ -391,8 +557,9 @@ values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -413,6 +580,10 @@ val valuesAndPreds = parsedData.map { point => } val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) @@ -483,6 +654,10 @@ public class LinearRegression { } ).rdd()).mean(); System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "myModelPath"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -494,6 +669,8 @@ The example then uses LinearRegressionWithSGD to build a simple linear model to values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD from numpy import array @@ -523,7 +700,7 @@ section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. -## Streaming linear regression +###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, updating the parameters of the model as new data arrives. MLlib currently supports @@ -531,7 +708,7 @@ streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. -### Examples +**Examples** The following example demonstrates how to load training and testing data from two different input streams of text files, parse the streams as labeled points, fit a linear regression model @@ -598,7 +775,7 @@ will get better!
    -## Implementation (developer) +# Implementation (developer) Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the MLlib - Old Migration Guides +description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +--- + +The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). + +## From 1.1 to 1.2 + +The only API changes in MLlib v1.2 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.2: + +1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number +of classes. In MLlib v1.1, this argument was called `numClasses` in Python and +`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. +This `numClasses` parameter is specified either via +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Breaking change)* The API for +[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. +This should generally not affect user code, unless the user manually constructs decision trees +(instead of using the `trainClassifier` or `trainRegressor` methods). +The tree `Node` now includes more information, including the probability of the predicted label +(for classification). + +3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. + +Examples in the Spark distribution and examples in the +[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. + +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + +## From 0.9 to 1.0 + +In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few +breaking changes. If your data is sparse, please store it in a sparse format instead of dense to +take advantage of sparsity in both storage and computation. Details are described below. + diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index d5b044d94fdd..9780ea52c499 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -13,12 +13,15 @@ compute the conditional probability distribution of label given an observation and use it for prediction. MLlib supports [multinomial naive -Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), -which is typically used for [document -classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) +and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification] +(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each -feature represents a term whose value is the frequency of the term. -Feature values must be nonnegative to represent term frequencies. +feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). +Feature values must be nonnegative. The model type is selected with an optional parameter +"Multinomial" or "Bernoulli" with "Multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -32,12 +35,12 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, and output a +smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. {% highlight scala %} -import org.apache.spark.mllib.classification.NaiveBayes +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -51,10 +54,14 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0) +val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} @@ -93,34 +100,46 @@ double accuracy = predictionAndLabel.filter(new Function, return pl._1().equals(pl._2()); } }).count() / (double) test.count(); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    -[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial +[NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial naive Bayes. It takes an RDD of -[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optionally +[LabeledPoint](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) and an optionally smoothing parameter `lambda` as input, and output a -[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be +[NaiveBayesModel](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. - +Note that the Python API does not yet support model save/load but will in the future. + {% highlight python %} -from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint + +def parseLine(line): + parts = line.split(',') + label = float(parts[0]) + features = Vectors.dense([float(x) for x in parts[1].split(' ')]) + return LabeledPoint(label, features) + +data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) -# an RDD of LabeledPoint -data = sc.parallelize([ - LabeledPoint(0.0, [0.0, 0.0]) - ... # more labeled points -]) +# Split data aproximately into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 0) # Train a naive Bayes model. -model = NaiveBayes.train(data, 1.0) +model = NaiveBayes.train(training, 1.0) -# Make prediction. -prediction = model.predict([0.0, 0.0]) +# Make prediction and test accuracy. +predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) +accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() {% endhighlight %}
    diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 4d101afca2c9..6cabc1610a15 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -203,6 +203,10 @@ regularization, as well as L2 regularizer. recommended. * `maxNumIterations` is the maximal number of iterations that L-BFGS can be run. * `regParam` is the regularization parameter when using regularization. +* `convergenceTol` controls how much relative change is still allowed when L-BFGS +is considered to converge. This must be nonnegative. Lower values are less tolerant and +therefore generally cause more iterations to be run. This value looks at both average +improvement and the norm of gradient inside [Breeze LBFGS](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGS.scala). The `return` is a tuple containing two elements. The first element is a column matrix containing weights for every feature, and the second element is an array containing diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ca8c29218f52..887eae7f4f07 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -81,8 +81,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column
    -[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of -[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html), +[`colStats()`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics.colStats) returns an instance of +[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary), which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. @@ -169,7 +169,7 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson");
    -[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. @@ -258,7 +258,7 @@ JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); {% endhighlight %}
    -[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to +[`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of keys. @@ -476,7 +476,7 @@ JavaDoubleRDD v = u.map(
    -[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +[`RandomRDDs`](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) provides factory methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. diff --git a/docs/monitoring.md b/docs/monitoring.md index f32cdef240d3..8a85928d6d44 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -1,6 +1,7 @@ --- layout: global title: Monitoring and Instrumentation +description: Monitoring, metrics, and instrumentation guide for Spark SPARK_VERSION_SHORT --- There are several ways to monitor Spark applications: web UIs, metrics, and external instrumentation. @@ -85,10 +86,10 @@ follows: - spark.history.fs.updateInterval - 10 + spark.history.fs.update.interval + 10s - The period, in seconds, at which information displayed by this history server is updated. + The period at which information displayed by this history server is updated. Each update checks for any changes made to the event logs in persisted storage. @@ -144,11 +145,35 @@ follows: If disabled, no access control checks are made. + + spark.history.fs.cleaner.enabled + false + + Specifies whether the History Server should periodically clean up event logs from storage. + + + + spark.history.fs.cleaner.interval + 1d + + How often the job history cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.cleaner.maxAge. + + + + spark.history.fs.cleaner.maxAge + 7d + + Job history files older than this will be deleted when the history cleaner runs. + + Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. +Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` to handle the Spark Context setup and tear down, and still show the job history on the UI. + # Metrics Spark has a configurable metrics system based on the @@ -175,6 +200,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `JmxSink`: Registers metrics for viewing in a JMX console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. * `GraphiteSink`: Sends metrics to a Graphite node. +* `Slf4jSink`: Sends metrics to slf4j as log entries. Spark also supports a Ganglia sink which is not included in the default build due to licensing restrictions: diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 2443fc29b470..27816515c5de 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1,6 +1,7 @@ --- layout: global title: Spark Programming Guide +description: Spark SPARK_VERSION_SHORT programming guide in Java, Scala and Python --- * This will become a table of contents (this text will be scraped). @@ -141,8 +142,8 @@ JavaSparkContext sc = new JavaSparkContext(conf);
    -The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark -how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object +The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark +how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object that contains information about your application. {% highlight python %} @@ -172,8 +173,11 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. -For example, to run `bin/spark-shell` on exactly four cores, use: +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly +four cores, use: {% highlight bash %} $ ./bin/spark-shell --master local[4] @@ -185,6 +189,12 @@ Or, to also add `code.jar` to its classpath, use: $ ./bin/spark-shell --master local[4] --jars code.jar {% endhighlight %} +To include a dependency using maven coordinates: + +{% highlight bash %} +$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1" +{% endhighlight %} + For a complete list of options, run `spark-shell --help`. Behind the scenes, `spark-shell` invokes the more general [`spark-submit` script](submitting-applications.html). @@ -195,7 +205,11 @@ For a complete list of options, run `spark-shell --help`. Behind the scenes, In the PySpark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files -to the runtime path by passing a comma-separated list to `--py-files`. +to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: {% highlight bash %} @@ -223,9 +237,13 @@ You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +your notebook before you start to try Spark from the IPython notebook. +
    @@ -321,7 +339,7 @@ Apart from text files, Spark's Scala API also supports several other data format * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts. -* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `RDD.saveAsObjectFile` and `SparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -353,7 +371,7 @@ Apart from text files, Spark's Java API also supports several other data formats * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). -* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). +* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`). * `JavaRDD.saveAsObjectFile` and `JavaSparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD. @@ -711,7 +729,7 @@ class MyClass(object): def __init__(self): self.field = "Hello" def doStuff(self, rdd): - return rdd.map(lambda s: self.field + x) + return rdd.map(lambda s: self.field + s) {% endhighlight %} To avoid this issue, the simplest way is to copy `field` into a local variable instead @@ -720,13 +738,76 @@ of accessing it externally: {% highlight python %} def doStuff(self, rdd): field = self.field - return rdd.map(lambda s: field + x) + return rdd.map(lambda s: field + s) +{% endhighlight %} + + + + + +### Understanding closures +One of the harder things about Spark is understanding the scope and life cycle of variables and methods when executing code across a cluster. RDD operations that modify variables outside of their scope can be a frequent source of confusion. In the example below we'll look at code that uses `foreach()` to increment a counter, but similar issues can occur for other operations as well. + +#### Example + +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): + +
    + +
    +{% highlight scala %} +var counter = 0 +var rdd = sc.parallelize(data) + +// Wrong: Don't do this!! +rdd.foreach(x => counter += x) + +println("Counter value: " + counter) +{% endhighlight %} +
    + +
    +{% highlight java %} +int counter = 0; +JavaRDD rdd = sc.parallelize(data); + +// Wrong: Don't do this!! +rdd.foreach(x -> counter += x); + +println("Counter value: " + counter); {% endhighlight %} +
    + +
    +{% highlight python %} +counter = 0 +rdd = sc.parallelize(data) + +# Wrong: Don't do this!! +rdd.foreach(lambda x: counter += x) + +print("Counter value: " + counter) +{% endhighlight %}
    +#### Local vs. cluster modes + +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. + +However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. + +What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only sees the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. + +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. + +In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. + +#### Printing elements of an RDD +Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. + ### Working with Key-Value Pairs
    @@ -835,7 +916,7 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -856,7 +937,7 @@ for details. Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item). - mapPartitions(func) + mapPartitions(func) Similar to map, but runs separately on each partition (block) of the RDD, so func must be of type Iterator<T> => Iterator<U> when running on an RDD of type T. @@ -883,10 +964,10 @@ for details. Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or combineByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
    Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -894,25 +975,25 @@ for details. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. @@ -925,17 +1006,17 @@ for details. process's stdin and lines output to its stdout are returned as an RDD of strings. - coalesce(numPartitions) + coalesce(numPartitions) Decrease the number of partitions in the RDD to numPartitions. Useful for running operations more efficiently after filtering down a large dataset. repartition(numPartitions) Reshuffle the data in the RDD randomly to create either more or fewer partitions and balance it across them. - This always shuffles all data over the network. + This always shuffles all data over the network. - repartitionAndSortWithinPartitions(partitioner) + repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -948,7 +1029,7 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.rdd.RDD-class.html)) + [Python](api/python/pyspark.html#pyspark.RDD)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -974,7 +1055,7 @@ for details. take(n) - Return an array with the first n elements of the dataset. Note that this is currently not executed in parallel. Instead, the driver program computes all the elements. + Return an array with the first n elements of the dataset. takeSample(withReplacement, num, [seed]) @@ -999,15 +1080,77 @@ for details. SparkContext.objectFile(). - countByKey() + countByKey() Only available on RDDs of type (K, V). Returns a hashmap of (K, Int) pairs with the count of each key. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an accumulator variable (see below) or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. +
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. +### Shuffle operations + +Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's +mechanism for re-distributing data so that it's grouped differently across partitions. This typically +involves copying data across executors and machines, making the shuffle a complex and +costly operation. + +#### Background + +To understand what happens during the shuffle we can consider the example of the +[`reduceByKey`](#ReduceByLink) operation. The `reduceByKey` operation generates a new RDD where all +values for a single key are combined into a tuple - the key and the result of executing a reduce +function against all values associated with that key. The challenge is that not all values for a +single key necessarily reside on the same partition, or even the same machine, but they must be +co-located to compute the result. + +In Spark, data is generally not distributed across partitions to be in the necessary place for a +specific operation. During computations, a single task will operate on a single partition - thus, to +organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - +this is called the **shuffle**. + +Although the set of elements in each partition of newly shuffled data will be deterministic, and so +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: + +* `mapPartitions` to sort each partition using, for example, `.sorted` +* `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning +* `sortBy` to make a globally ordered RDD + +Operations which can cause a shuffle include **repartition** operations like +[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +(except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and +**join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). + +#### Performance Impact +The **Shuffle** is an expensive operation since it involves disk I/O, data serialization, and +network I/O. To organize data for the shuffle, Spark generates sets of tasks - *map* tasks to +organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from +MapReduce and does not directly relate to Spark's `map` and `reduce` operations. + +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks +read the relevant sorted blocks. + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables +to disk, incurring the additional overhead of disk I/O and increased garbage collection. + +Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files +are not cleaned up from Spark's temporary storage until Spark is stopped, which means that +long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need +to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +`spark.local.dir` configuration parameter when configuring the Spark context. + +Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). + ## RDD Persistence One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory @@ -1027,7 +1170,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), -[Python](api/python/pyspark.storagelevel.StorageLevel-class.html)) +[Python](api/python/pyspark.html#pyspark.StorageLevel)) to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of storage levels is: @@ -1129,6 +1272,12 @@ than shipping a copy of it with tasks. They can be used, for example, to give ev large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables using efficient broadcast algorithms to reduce communication cost. +Spark actions are executed through a set of stages, separated by distributed "shuffle" operations. +Spark automatically broadcasts the common data needed by tasks within each stage. The data +broadcasted this way is cached in serialized form and deserialized before running each task. This +means that explicitly creating broadcast variables is only useful when tasks across multiple stages +need the same data or when caching the data in deserialized form is important. + Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value` method. The code below shows this: @@ -1177,7 +1326,7 @@ run on the cluster so that `v` is not shipped to the nodes more than once. In ad `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). -## Accumulators +## Accumulators Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in @@ -1290,7 +1439,7 @@ scala> accum.value {% endhighlight %} While this code used the built-in support for accumulators of type Int, programmers can also -create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html). +create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: @@ -1322,25 +1471,28 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being
    {% highlight scala %} -val acc = sc.accumulator(0) -data.map(x => acc += x; f(x)) -// Here, acc is still 0 because no actions have cause the `map` to be computed. +val accum = sc.accumulator(0) +data.map { x => accum += x; f(x) } +// Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    {% highlight java %} Accumulator accum = sc.accumulator(0); -data.map(x -> accum.add(x); f(x);); -// Here, accum is still 0 because no actions have cause the `map` to be computed. +data.map(x -> { accum.add(x); return f(x); }); +// Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    {% highlight python %} accum = sc.accumulator(0) -data.map(lambda x => acc.add(x); f(x)) -# Here, acc is still 0 because no actions have cause the `map` to be computed. +def g(x): + accum.add(x) + return f(x) +data.map(g) +# Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %}
    @@ -1352,6 +1504,11 @@ The [application submission guide](submitting-applications.html) describes how t In short, once you package your application into a JAR (for Java/Scala) or a set of `.py` or `.zip` files (for Python), the `bin/spark-submit` script lets you submit it to any supported cluster manager. +# Launching Spark jobs from Java / Scala + +The [org.apache.spark.launcher](api/java/index.html?org/apache/spark/launcher/package-summary.html) +package provides classes for launching Spark jobs as child processes using a simple Java API. + # Unit Testing Spark is friendly to unit testing with any popular unit test framework. diff --git a/docs/quick-start.md b/docs/quick-start.md index bf643bb70e15..81143da865cf 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -1,6 +1,7 @@ --- layout: global title: Quick Start +description: Quick start tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 78358499fd01..594bf78b6771 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -110,7 +110,7 @@ cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooK The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: - * `export MESOS_NATIVE_LIBRARY=`. This path is typically + * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically `/lib/libmesos.so` where the prefix is `/usr/local` by default. See Mesos installation instructions above. On Mac OS X, the library is called `libmesos.dylib` instead of `libmesos.so`. @@ -167,9 +167,6 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). -# Known issues -- When using the "fine-grained" mode, make sure that your executors always leave 32 MB free on the slaves. Otherwise it can happen that your Spark job does not proceed anymore. Currently, Apache Mesos only offers resources if there are at least 32 MB memory allocatable. But as Spark allocates memory only for the executor and cpu only for tasks, it can happen on high slave memory usage that no new tasks will be started anymore. More details can be found in [MESOS-1688](https://issues.apache.org/jira/browse/MESOS-1688). Alternatively use the "coarse-gained" mode, which is not affected by this issue. - # Running Alongside Hadoop You can run Spark and Mesos alongside your existing Hadoop cluster by just launching them as a @@ -197,7 +194,11 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - Set the run mode for Spark on Mesos. For more information about the run mode, refer to #Mesos Run Mode section above. + If set to "true", runs over Mesos clusters in + "coarse-grained" sharing mode, + where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per + Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use + for the whole duration of the Spark job. @@ -209,21 +210,33 @@ See the [configuration page](configuration.html) for information on Spark config Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + + spark.mesos.mesosExecutor.cores + 1.0 + + (Fine-grained mode only) Number of cores to give each Mesos executor. This does not + include the cores used to run the Spark tasks. In other words, even if no Spark task + is being run, each Mesos executor will occupy the number of cores configured here. + The value can be a floating point number. + + spark.mesos.executor.home - SPARK_HOME + driver side SPARK_HOME - The location where the mesos executor will look for Spark binaries to execute, and uses the SPARK_HOME setting on default. - This variable is only used when no spark.executor.uri is provided, and assumes Spark is installed on the specified location - on each slave. + Set the directory in which Spark is installed on the executors in Mesos. By default, the + executors will simply use the driver's Spark home directory, which may not be visible to + them. Note that this is only relevant if a Spark binary package is not specified through + spark.executor.uri. spark.mesos.executor.memoryOverhead - 384 + executor memory * 0.10, with minimum of 384 - The amount of memory that Mesos executor will request for the task to account for the overhead of running the executor itself. - The final total amount of memory allocated is the maximum value between executor memory plus memoryOverhead, and overhead fraction (1.07) plus the executor memory. + The amount of additional memory, specified in MB, to be allocated per executor. By default, + the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the final overhead will be this value. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68ab127bcf08..0968fc5ad632 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -48,9 +48,9 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.am.waitTime - 100000 + 100s - In yarn-cluster mode, time in milliseconds for the application master to wait for the + In yarn-cluster mode, time for the application master to wait for the SparkContext to be initialized. In yarn-client mode, time for the application master to wait for the driver to connect to it. @@ -87,7 +87,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. @@ -104,9 +105,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Comma-separated list of files to be placed in the working directory of each executor. + + spark.executor.instances + 2 + + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + + spark.yarn.executor.memoryOverhead - executorMemory * 0.07, with minimum of 384 + executorMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). @@ -189,12 +197,25 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.submit.waitAppCompletion + true + + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. + + # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. +These configs are used to write to the dfs and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. @@ -267,6 +288,6 @@ If you need a reference to the proper location to put log files in the YARN so t # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. +- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/security.md b/docs/security.md index 1e206a139fb7..c034ba12ff1f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark Security +displayTitle: Spark Security +title: Security --- Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: @@ -20,6 +21,30 @@ Spark allows for a set of administrators to be specified in the acls who always If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +## Encryption + +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. + +Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). + +SSL must be configured on each node and configured for each component involved in communication using the particular protocol. + +### YARN mode +The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. + +### Standalone mode +The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +### Preparing the key-stores +Key-stores can be generated by `keytool` program. The reference documentation for this tool is +[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic +steps to configure the key-stores and the trust-store for the standalone deployment mode is as +follows: +* Generate a keys pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust-store +* Distribute the trust-store over the nodes + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 5c6084fb4625..0eed9adacf12 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./bin/spark-class org.apache.spark.deploy.worker.Worker spark://IP:PORT + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -81,6 +81,7 @@ Once you've set up this file, you can launch or stop your cluster with the follo - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `sbin/start-slave.sh` - Starts a slave instance on the machine the script is executed on. - `sbin/start-all.sh` - Starts both a master and a number of slaves as described above. - `sbin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. - `sbin/stop-slaves.sh` - Stops all slave instances on the machines specified in the `conf/slaves` file. @@ -222,8 +223,7 @@ SPARK_WORKER_OPTS supports the following system properties: false Enable periodic cleanup of worker / application directories. Note that this only affects standalone - mode, as YARN works differently. Applications directories are cleaned up regardless of whether - the application is still running. + mode, as YARN works differently. Only the directories of stopped applications are cleaned up. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index be8c5c2c1522..b8233ae06fdf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,7 @@ --- layout: global -title: Spark SQL Programming Guide +displayTitle: Spark SQL and DataFrame Guide +title: Spark SQL and DataFrames --- * This will become a table of contents (this text will be scraped). @@ -8,168 +9,348 @@ title: Spark SQL Programming Guide # Overview +Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. + + +# DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. + +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. + + +## Starting Point: `SQLContext` +
    -Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.`SQLContext`) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ +{% endhighlight %}
    -
    -Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with -a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +
    + +The entry point into all functionality in Spark SQL is the +[`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} +
    -Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using -Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of -[Row](api/python/pyspark.sql.Row-class.html) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) -file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +The entry point into all relational functionality in Spark is the +[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +{% endhighlight %} -All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
    -**Spark SQL is currently an alpha component. While we will minimize API changes, some APIs may change in future releases.** +In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a +superset of the functionality provided by the basic `SQLContext`. Additional features include +the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +existing Hive setup, and all of the data sources available to a `SQLContext` are still available. +`HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +to feature parity with a `HiveContext`. + +The specific variant of SQL that is used to parse queries can also be selected using the +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +this is recommended for most use cases. + + +## Creating DataFrames -*************************************************************************************************** +With a `SQLContext`, applications can create `DataFrame`s from an existing `RDD`, from a Hive table, or from data sources. -# Getting Started +As an example, the following creates a `DataFrame` based on the content of a JSON file:
    - -The entry point into all relational functionality in Spark is the -[SQLContext](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic SQLContext, all you need is a SparkContext. - {% highlight scala %} val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD -{% endhighlight %} +val df = sqlContext.jsonFile("examples/src/main/resources/people.json") -In addition to the basic SQLContext, you can also create a HiveContext, which provides a -superset of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +// Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %}
    - -The entry point into all relational functionality in Spark is the -[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one -of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. - {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); -{% endhighlight %} +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); + +// Displays the content of the DataFrame to stdout +df.show(); +{% endhighlight %}
    +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) -The entry point into all relational functionality in Spark is the -[SQLContext](api/python/pyspark.sql.SQLContext-class.html) class, or one -of its decedents. To create a basic SQLContext, all you need is a SparkContext. +df = sqlContext.jsonFile("examples/src/main/resources/people.json") +# Displays the content of the DataFrame to stdout +df.show() +{% endhighlight %} + +
    +
    + + +## DataFrame Operations + +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +Here we include some basic examples of structured data processing using DataFrames: + + +
    +
    +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create the DataFrame +val df = sqlContext.jsonFile("examples/src/main/resources/people.json") + +// Show the content of the DataFrame +df.show() +// age name +// null Michael +// 30 Andy +// 19 Justin + +// Print the schema in a tree format +df.printSchema() +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show() +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df("name"), df("age") + 1).show() +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df("age") > 21).show() +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show() +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %} + +
    + +
    +{% highlight java %} +JavaSparkContext sc // An existing SparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create the DataFrame +DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); + +// Show the content of the DataFrame +df.show(); +// age name +// null Michael +// 30 Andy +// 19 Justin + +// Print the schema in a tree format +df.printSchema(); +// root +// |-- age: long (nullable = true) +// |-- name: string (nullable = true) + +// Select only the "name" column +df.select("name").show(); +// name +// Michael +// Andy +// Justin + +// Select everybody, but increment the age by 1 +df.select(df.col("name"), df.col("age").plus(1)).show(); +// name (age + 1) +// Michael null +// Andy 31 +// Justin 20 + +// Select people older than 21 +df.filter(df.col("age").gt(21)).show(); +// age name +// 30 Andy + +// Count people by age +df.groupBy("age").count().show(); +// age count +// null 1 +// 19 1 +// 30 1 +{% endhighlight %} + +
    + +
    {% highlight python %} from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -{% endhighlight %} -In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict -super set of the functionality provided by the basic SQLContext. Additional features include -the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the -ability to read data from Hive tables. To use a HiveContext, you do not need to have an -existing Hive setup, and all of the data sources available to a SQLContext are still available. -HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using HiveContext -is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to -feature parity with a HiveContext. +# Create the DataFrame +df = sqlContext.jsonFile("examples/src/main/resources/people.json") + +# Show the content of the DataFrame +df.show() +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +df.printSchema() +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +df.select("name").show() +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +df.select(df.name, df.age + 1).show() +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +df.filter(df.age > 21).show() +## age name +## 30 Andy + +# Count people by age +df.groupBy("age").count().show() +## age count +## null 1 +## 19 1 +## 30 1 +{% endhighlight %} + +
    + +## Running SQL Queries Programmatically + +The `sql` function on a `SQLContext` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +
    +{% highlight scala %} +val sqlContext = ... // An existing SQLContext +val df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %}
    -The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, - this is recommended for most use cases. +
    +{% highlight java %} +SQLContext sqlContext = ... // An existing SQLContext +DataFrame df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
    -# Data Sources +
    +{% highlight python %} +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) +df = sqlContext.sql("SELECT * FROM table") +{% endhighlight %} +
    +
    -Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section -describes the various methods for loading data into a SchemaRDD. -## RDDs +## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating SchemaRDDs is through a programmatic interface that allows you to +The second method for creating DataFrames is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct SchemaRDDs when the columns and their types are not known until runtime. +you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
    -The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes -to a SchemaRDD. The case class +The Scala interface for Spark SQL supports automatically converting an RDD containing case classes +to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -177,13 +358,13 @@ import sqlContext.createSchemaRDD case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. -val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) +val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF() people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -193,7 +374,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -224,12 +405,12 @@ public static class Person implements Serializable { {% endhighlight %} -A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -246,15 +427,15 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); +DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -266,7 +447,7 @@ List teenagerNames = teenagers.map(new Function() {
    -Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we @@ -283,11 +464,11 @@ lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) -# Infer the schema, and register the SchemaRDD as a table. +# Infer the schema, and register the DataFrame as a table. schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -309,12 +490,12 @@ for teenName in teenNames.collect(): When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided by `SQLContext`. For example: @@ -328,8 +509,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt") // The schema is encoded in a string val schemaString = "name age" -// Import Spark SQL data types and Row. -import org.apache.spark.sql._ +// Import Row. +import org.apache.spark.sql.Row; + +// Import Spark SQL data types +import org.apache.spark.sql.types.{StructType,StructField,StringType}; // Generate the schema based on the string of schema val schema = @@ -340,15 +524,15 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) +val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people") +// Register the DataFrames as a table. +peopleDataFrame.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val results = sqlContext.sql("SELECT name FROM people") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. results.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -361,26 +545,29 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. -3. Apply the schema to the RDD of `Row`s via `applySchema` method provided -by `JavaSQLContext`. +3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided +by `SQLContext`. For example: {% highlight java %} -// Import factory methods provided by DataType. -import org.apache.spark.sql.api.java.DataType +import org.apache.spark.api.java.function.Function; +// Import factory methods provided by DataTypes. +import org.apache.spark.sql.types.DataTypes; // Import StructType and StructField -import org.apache.spark.sql.api.java.StructType -import org.apache.spark.sql.api.java.StructField +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.StructField; // Import Row. -import org.apache.spark.sql.api.java.Row +import org.apache.spark.sql.Row; +// Import RowFactory. +import org.apache.spark.sql.RowFactory; // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); @@ -391,31 +578,31 @@ String schemaString = "name age"; // Generate the schema based on the string of schema List fields = new ArrayList(); for (String fieldName: schemaString.split(" ")) { - fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); + fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); } -StructType schema = DataType.createStructType(fields); +StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows. JavaRDD rowRDD = people.map( new Function() { public Row call(String record) throws Exception { String[] fields = record.split(","); - return Row.create(fields[0], fields[1].trim()); + return RowFactory.create(fields[0], fields[1].trim()); } }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); +DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people"); +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); +DataFrame results = sqlContext.sql("SELECT name FROM people"); -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List names = results.map(new Function() { +List names = results.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -430,17 +617,18 @@ List names = results.map(new Function() { When a dictionary of kwargs cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. For example: {% highlight python %} # Import SQLContext and data types -from pyspark.sql import * +from pyspark.sql import SQLContext +from pyspark.sql.types import * # sc is an existing SparkContext. sqlContext = SQLContext(sc) @@ -457,12 +645,12 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt schema = StructType(fields) # Apply the schema to the RDD. -schemaPeople = sqlContext.applySchema(people, schema) +schemaPeople = sqlContext.createDataFrame(people, schema) -# Register the SchemaRDD as a table. +# Register the DataFrame as a table. schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -471,11 +659,157 @@ for name in names.collect(): print name {% endhighlight %} +
    + +
    + + +# Data Sources + +Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. +A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a DataFrame as a table allows you to run SQL queries over its data. This section +describes the general methods for loading and saving data using the Spark Data Sources and then +goes into specific options that are available for the built-in data sources. + +## Generic Load/Save Functions + +In the simplest form, the default data source (`parquet` unless otherwise configured by +`spark.sql.sources.default`) will be used for all operations. + +
    +
    + +{% highlight scala %} +val df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") +{% endhighlight %} + +
    + +
    + +{% highlight java %} + +DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").save("namesAndFavColors.parquet"); + +{% endhighlight %} + +
    + +
    + +{% highlight python %} + +df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") + +{% endhighlight %} + +
    +
    + +### Manually Specifying Options + +You can also manually specify the data source that will be used along with any extra options +that you would like to pass to the data source. Data sources are specified by their fully qualified +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted +name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +using this syntax. + +
    +
    + +{% highlight scala %} +val df = sqlContext.load("examples/src/main/resources/people.json", "json") +df.select("name", "age").save("namesAndAges.parquet", "parquet") +{% endhighlight %}
    +
    + +{% highlight java %} + +DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); +df.select("name", "age").save("namesAndAges.parquet", "parquet"); + +{% endhighlight %} +
    +
    + +{% highlight python %} + +df = sqlContext.load("examples/src/main/resources/people.json", "json") +df.select("name", "age").save("namesAndAges.parquet", "parquet") + +{% endhighlight %} + +
    +
    + +### Save Modes + +Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. +Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +new data. + + + + + + + + + + + + + + + + + + + + + + + +
    Scala/JavaPythonMeaning
    SaveMode.ErrorIfExists (default)"error" (default) + When saving a DataFrame to a data source, if data already exists, + an exception is expected to be thrown. +
    SaveMode.Append"append" + When saving a DataFrame to a data source, if data/table already exists, + contents of the DataFrame are expected to be appended to existing data. +
    SaveMode.Overwrite"overwrite" + Overwrite mode means that when saving a DataFrame to a data source, + if data/table already exists, existing data is expected to be overwritten by the contents of + the DataFrame. +
    SaveMode.Ignore"ignore" + Ignore mode means that when saving a DataFrame to a data source, if data already exists, + the save operation is expected to not save the contents of the DataFrame and to not + change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. +
    + +### Saving to Persistent Tables + +When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +will still exist even after your Spark program has restarted, as long as you maintain your connection +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +method on a `SQLContext` with the name of the table. + +By default `saveAsTable` will create a "managed table", meaning that the location of the data will +be controlled by the metastore. Managed tables will also have their data deleted automatically +when a table is dropped. + ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. @@ -492,16 +826,16 @@ Using the data from the above example: {% highlight scala %} // sqlContext from the previous example is used in this example. -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a SchemaRDD. +// The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. @@ -517,19 +851,19 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% highlight java %} // sqlContext from the previous example is used in this example. -JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. +DataFrame schemaPeople = ... // The DataFrame from the previous example. -// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information. +// DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); +// The result of loading a parquet file is also a DataFrame. +DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); -List teenagerNames = teenagers.map(new Function() { +DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -543,13 +877,13 @@ List teenagerNames = teenagers.map(new Function() { {% highlight python %} # sqlContext from the previous example is used in this example. -schemaPeople # The SchemaRDD from the previous example. +schemaPeople # The DataFrame from the previous example. -# SchemaRDDs can be saved as Parquet files, maintaining the schema information. +# DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a SchemaRDD. +# The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. @@ -562,11 +896,150 @@ for teenName in teenNames.collect():
    +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE parquetTable +USING org.apache.spark.sql.parquet +OPTIONS ( + path "examples/src/main/resources/people.parquet" +) + +SELECT * FROM parquetTable + +{% endhighlight %} + +
    + +
    + +### Partition discovery + +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +table, data are usually stored in different directories, with partitioning column values encoded in +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For exmaple, we can store all our previously used +population data into a partitioned table using the following directory structure, with two extra +columns, `gender` and `country` as partitioning columns: + +{% highlight text %} + +path +└── to + └── table + ├── gender=male + │   ├── ... + │   │ + │   ├── country=US + │   │   └── data.parquet + │   ├── country=CN + │   │   └── data.parquet + │   └── ... + └── gender=female +    ├── ... +    │ +    ├── country=US +    │   └── data.parquet +    ├── country=CN +    │   └── data.parquet +    └── ... + +{% endhighlight %} + +By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will +automatically extract the partitioning information from the paths. Now the schema of the returned +DataFrame becomes: + +{% highlight text %} + +root +|-- name: string (nullable = true) +|-- age: long (nullable = true) +|-- gender: string (nullable = true) +|-- country: string (nullable = true) + +{% endhighlight %} + +Notice that the data types of the partitioning columns are automatically inferred. Currently, +numeric data types and string type are supported. + +### Schema merging + +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +source is now able to automatically detect this case and merge schemas of all these files. + +
    + +
    + +{% highlight scala %} +// sqlContext from the previous example is used in this example. +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ + +// Create a simple DataFrame, stored into a partition directory +val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +df1.saveAsParquetFile("data/test_table/key=1") + +// Create another DataFrame in a new partition directory, +// adding a new column and dropping an existing column +val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +df2.saveAsParquetFile("data/test_table/key=2") + +// Read the partitioned table +val df3 = sqlContext.parquetFile("data/test_table") +df3.printSchema() + +// The final schema consists of all 3 columns in the Parquet files together +// with the partiioning column appeared in the partition directory paths. +// root +// |-- single: int (nullable = true) +// |-- double: int (nullable = true) +// |-- triple: int (nullable = true) +// |-- key : int (nullable = true) +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ + .map(lambda i: Row(single=i, double=i * 2))) +df1.save("data/test_table/key=1", "parquet") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) + .map(lambda i: Row(single=i, triple=i * 3))) +df2.save("data/test_table/key=2", "parquet") + +# Read the partitioned table +df3 = sqlContext.parquetFile("data/test_table") +df3.printSchema() + +# The final schema consists of all 3 columns in the Parquet files together +# with the partiioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
    +
    ### Configuration -Configuration of Parquet can be done using the `setConf` method on SQLContext or by running +Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running `SET key=value` commands using SQL. @@ -580,6 +1053,15 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. + + + + + @@ -619,8 +1101,8 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext`: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -636,7 +1118,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a SchemaRDD from the file(s) pointed to by path +// Create a DataFrame from the file(s) pointed to by path val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. @@ -645,13 +1127,13 @@ people.printSchema() // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this SchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) @@ -661,8 +1143,8 @@ val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. -This conversion can be done using one of two methods in a JavaSQLContext : +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext` : * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -673,13 +1155,13 @@ a regular multi-line JSON file will most often fail. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; -// Create a JavaSchemaRDD from the file(s) pointed to by path -JavaSchemaRDD people = sqlContext.jsonFile(path); +// Create a DataFrame from the file(s) pointed to by path +DataFrame people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -687,24 +1169,24 @@ people.printSchema(); // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this JavaSchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); -// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); {% endhighlight %}
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. -This conversion can be done using one of two methods in a SQLContext: +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a `SQLContext`: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -721,7 +1203,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = "examples/src/main/resources/people.json" -# Create a SchemaRDD from the file(s) pointed to by path +# Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. @@ -730,13 +1212,13 @@ people.printSchema() # |-- age: integer (nullable = true) # |-- name: string (nullable = true) -# Register this SchemaRDD as a table. +# Register this DataFrame as a table. people.registerTempTable("people") -# SQL statements can be run by using the sql methods provided by sqlContext. +# SQL statements can be run by using the sql methods provided by `sqlContext`. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# Alternatively, a DataFrame can be created for a JSON dataset represented by # an RDD[String] storing one JSON object per string. anotherPeopleRDD = sc.parallelize([ '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) @@ -744,6 +1226,22 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
    +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE jsonTable +USING org.apache.spark.sql.json +OPTIONS ( + path "examples/src/main/resources/people.json" +) + +SELECT * FROM jsonTable + +{% endhighlight %} + +
    +
    ## Hive Tables @@ -763,7 +1261,7 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a HiveContext. When not configured by the +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current directory. @@ -782,14 +1280,14 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
    -When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -824,6 +1322,124 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
    +## JDBC To Other Databases + +Spark SQL also includes a data source that can read data from other databases using JDBC. This +functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). +This is because the results are returned +as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. +The JDBC data source is also easier to use from Java or Python as it does not require the user to +provide a ClassTag. +(Note that this is different than the Spark SQL JDBC server, which allows other applications to +run queries using Spark SQL). + +To get started you will need to include the JDBC driver for you particular database on the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the +following command: + +{% highlight bash %} +SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell +{% endhighlight %} + +Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using +the Data Sources API. The following options are supported: + +
    spark.sql.parquet.int96AsTimestamptrue + Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also + store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. +
    spark.sql.parquet.cacheMetadata true
    + + + + + + + + + + + + + + + + + + +
    Property NameMeaning
    url + The JDBC URL to connect to. +
    dbtable + The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + a SQL query can be used. For example, instead of a full table you could also use a + subquery in parentheses. +
    driver + The class name of the JDBC driver needed to connect to this URL. This class will be loaded + on the master and workers before running an JDBC commands to allow the driver to + register itself with the JDBC subsystem. +
    partitionColumn, lowerBound, upperBound, numPartitions + These options must all be specified if any of them is specified. They describe how to + partition the table when reading in parallel from multiple workers. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. +
    + +
    + +
    + +{% highlight scala %} +val jdbcDF = sqlContext.load("jdbc", Map( + "url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")) +{% endhighlight %} + +
    + +
    + +{% highlight java %} + +Map options = new HashMap(); +options.put("url", "jdbc:postgresql:dbserver"); +options.put("dbtable", "schema.tablename"); + +DataFrame jdbcDF = sqlContext.load("jdbc", options) +{% endhighlight %} + + +
    + +
    + +{% highlight python %} + +df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") + +{% endhighlight %} + +
    + +
    + +{% highlight sql %} + +CREATE TEMPORARY TABLE jdbcTable +USING org.apache.spark.sql.jdbc +OPTIONS ( + url "jdbc:postgresql:dbserver", + dbtable "schema.tablename" +) + +{% endhighlight %} + +
    +
    + +## Troubleshooting + + * The JDBC driver class must be visible to the primordial class loader on the client session and on all executors. This is because Java's DriverManager class does a security check that results in it ignoring all drivers not visible to the primordial class loader when one goes to open a connection. One convenient way to do this is to modify compute_classpath.sh on all worker nodes to include your driver JARs. + * Some databases, such as H2, convert all names to upper case. You'll need to use upper case to refer to those names in Spark SQL. + + # Performance Tuning For some workloads it is possible to improve performance by either caching data in memory, or by @@ -831,11 +1447,11 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +Configuration of in-memory caching can be done using the `setConf` method on `SQLContext` or by running `SET key=value` commands using SQL. @@ -894,15 +1510,14 @@ that these options will be deprecated in future release as more optimizations ar
    -# Other SQL Interfaces +# Distributed SQL Engine -Spark SQL also supports interfaces for running SQL queries directly without the need to write any -code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. ## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. +in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -947,10 +1562,10 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script that comes with Hive. -Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. -Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: +Thrift JDBC server also supports sending thrift RPC messages over HTTP transport. +Use the following setting to enable HTTP mode as system property or in `hive-site.xml` file in `conf/`: - hive.server2.transport.mode - Set this to value: http + hive.server2.transport.mode - Set this to value: http hive.server2.thrift.http.port - HTTP port number fo listen on; default is 10001 hive.server2.http.endpoint - HTTP endpoint; default is cliservice @@ -972,7 +1587,88 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. -# Compatibility with Other Systems +# Migration Guide + +## Upgrading from Spark SQL 1.0-1.2 to 1.3 + +In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +as unstable (i.e., DeveloperAPI or Experimental). + +#### Rename of SchemaRDD to DataFrame + +The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +directly, but instead provide most of the functionality that RDDs provide though their own +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. + +In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +some use cases. It is still recommended that users update their code to use `DataFrame` instead. +Java and Python users will need to update their code. + +#### Unification of the Java and Scala APIs + +Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +use types that are usable from both languages (i.e. `Array` instead of language specific collections). +In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading +is used instead. + +Additionally the Java specific types API has been removed. Users of both Scala and Java should +use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. + + +#### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) + +Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. +Users should now write `import sqlContext.implicits._`. + +Additionally, the implicit conversions now only augment RDDs that are composed of `Product`s (i.e., +case classes or tuples) with a method `toDF`, instead of applying automatically. + +When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`import org.apache.spark.sql.functions._`. + +#### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) + +Spark 1.3 removes the type aliases that were present in the base sql package for `DataType`. Users +should instead import the classes in `org.apache.spark.sql.types` + +#### UDF Registration Moved to `sqlContext.udf` (Java & Scala) + +Functions that are used to register UDFs, either for use in the DataFrame DSL or SQL, have been +moved into the udf object in `SQLContext`. + +
    +
    +{% highlight java %} + +sqlContext.udf.register("strLen", (s: String) => s.length()) + +{% endhighlight %} +
    + +
    +{% highlight java %} + +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); + +{% endhighlight %} +
    + +
    + +Python UDF registration is unchanged. + +#### Python DataTypes No Longer Singletons + +When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of +referencing a singleton. ## Migration Guide for Shark User @@ -1092,15 +1788,11 @@ in Hive deployments. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. + **Esoteric Hive Features** -* Tables with partitions using different input formats: In Spark SQL, all table partitions need to - have the same input format. -* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions - (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. -* `UNION` type and `DATE` type +* `UNION` type * Unique join -* Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at the moment and only supports populating the sizeInBytes field of the hive metastore. @@ -1116,9 +1808,6 @@ less important due to Spark SQL's in-memory computational model. Others are slot releases of Spark SQL. * Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically convert a join to map join: For joining a large table with multiple small tables, - Hive automatically converts the join into a map join. We are adding this auto conversion in the - next release. * Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". * Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still @@ -1129,33 +1818,10 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. -# Writing Language-Integrated Relational Queries -**Language-Integrated queries are experimental and currently only supported in Scala.** +# Data Types -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: - -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. - -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} - -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - - - -# Spark SQL DataType Reference +Spark SQL and DataFrames support the following data types: * Numeric types - `ByteType`: Represents 1-byte signed integer numbers. @@ -1198,10 +1864,10 @@ evaluated by the SQL execution engine. A full list of the functions supported c
    -All data types of Spark SQL are located in the package `org.apache.spark.sql`. +All data types of Spark SQL are located in the package `org.apache.spark.sql.types`. You can access them by doing {% highlight scala %} -import org.apache.spark.sql._ +import org.apache.spark.sql.types._ {% endhighlight %} @@ -1253,7 +1919,7 @@ import org.apache.spark.sql._ - + @@ -1447,7 +2113,7 @@ please use factory methods provided in - +
    DecimalType scala.math.BigDecimal java.math.BigDecimal DecimalType
    StructType org.apache.spark.sql.api.java.Row org.apache.spark.sql.Row DataTypes.createStructType(fields)
    Note: fields is a List or an array of StructFields. @@ -1468,10 +2134,10 @@ please use factory methods provided in
    -All data types of Spark SQL are located in the package of `pyspark.sql`. +All data types of Spark SQL are located in the package of `pyspark.sql.types`. You can access them by doing {% highlight python %} -from pyspark.sql import * +from pyspark.sql.types import * {% endhighlight %} diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index ac01dd3d8019..c8ab146bcae0 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +Python API Flume is not yet available in the Python API. + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -64,7 +66,7 @@ configuring Flume agents. 3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). -## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. - Flume pushes data into the sink, and the data stays buffered. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9..64714f0b799f 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,58 +2,155 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +## Approach 1: Receiver-based Approach +This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. + +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} -2. **Programming:** In the streaming application code, import `KafkaUtils` and create input DStream as follows. + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows.
    import org.apache.spark.streaming.kafka._ - val kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) + val kafkaStream = KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; - JavaPairReceiverInputDStream kafkaStream = KafkaUtils.createStream( - streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); + JavaPairReceiverInputDStream kafkaStream = + KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). - See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). +
    +
    + from pyspark.streaming.kafka import KafkaUtils + + kafkaStream = KafkaUtils.createStream(streamingContext, \ + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py).
    - *Points to remember:* + **Points to remember:** - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. -3. **Deploying:** Package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - -Note that the Kafka receiver used by default is an -[*unreliable* receiver](streaming-programming-guide.html#receiver-reliability) section in the -programming guide). In Spark 1.2, we have added an experimental *reliable* Kafka receiver that -provides stronger -[fault-tolerance guarantees](streaming-programming-guide.html#fault-tolerance-semantics) of zero -data loss on failures. This receiver is automatically used when the write ahead log -(also introduced in Spark 1.2) is enabled -(see [Deployment](#deploying-applications.html) section in the programming guide). This -may reduce the receiving throughput of individual Kafka receivers compared to the unreliable -receivers, but this can be corrected by running -[more receivers in parallel](streaming-programming-guide.html#level-of-parallelism-in-data-receiving) -to increase aggregate throughput. Additionally, it is recommended that the replication of the -received data within Spark be disabled when the write ahead log is enabled as the log is already stored -in a replicated storage system. This can be done by setting the storage level for the input -stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). + +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + +## Approach 2: Direct Approach (No Receivers) +This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. + +This approach has the following advantages over the received-based approach (i.e. Approach 1). + +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. + +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. + +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. + +Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val directKafkaStream = KafkaUtils.createDirectStream[ + [key class], [value class], [key decoder class], [value decoder class] ]( + streamingContext, [map of Kafka parameters], [set of topics to consume]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream directKafkaStream = + KafkaUtils.createDirectStream(streamingContext, + [key class], [value class], [key decoder class], [value decoder class], + [map of Kafka parameters], [set of topics to consume]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + +
    +
    + + In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. + By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. + + You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. + +
    +
    + directKafkaStream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] + // offsetRanges.length = # of Kafka partitions being consumed + ... + } +
    +
    + directKafkaStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges + // offsetRanges.length = # of Kafka partitions being consumed + ... + return null; + } + } + ); +
    +
    + + You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e37a2bb37b9a..2f2fea53168a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1,6 +1,8 @@ --- layout: global -title: Spark Streaming Programming Guide +displayTitle: Spark Streaming Programming Guide +title: Spark Streaming +description: Spark Streaming programming guide and tutorial for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). @@ -187,15 +189,15 @@ Next, we want to count these words. {% highlight java %} // Count each word in each batch -JavaPairDStream pairs = words.map( +JavaPairDStream pairs = words.mapToPair( new PairFunction() { - @Override public Tuple2 call(String s) throws Exception { + @Override public Tuple2 call(String s) { return new Tuple2(s, 1); } }); JavaPairDStream wordCounts = pairs.reduceByKey( new Function2() { - @Override public Integer call(Integer i1, Integer i2) throws Exception { + @Override public Integer call(Integer i1, Integer i2) { return i1 + i2; } }); @@ -430,7 +432,7 @@ some of the common ones are as follows.
    For an up-to-date list, please refer to the -[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) +[Maven repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) for the full list of supported sources and artifacts. *** @@ -660,8 +662,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. - Python API As of Spark 1.2, - `fileStream` is not available in the Python API, only `textFileStream` is available. + Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver @@ -680,8 +681,9 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. + +Python API As of Spark 1.3, +out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -702,7 +704,7 @@ create a DStream using data from Twitter's stream of tweets, you have to do the {% highlight scala %} import org.apache.spark.streaming.twitter._ -TwitterUtils.createStream(ssc) +TwitterUtils.createStream(ssc, None) {% endhighlight %}
    @@ -721,6 +723,12 @@ and it in the classpath. Some of these advanced sources are as follows. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. + +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. + +- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. + - **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by @@ -730,17 +738,10 @@ Some of these advanced sources are as follows. ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) and [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can received data from Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. - -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can receive data from Kafka 0.8.0. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. - ### Custom Sources {:.no_toc} -Python API As of Spark 1.2, -these sources are not available in the Python API. +Python API This is not yet supported in Python. Input DStreams can also be created out of custom data sources. All you have to do is implement an user-defined **receiver** (see next section to understand what that is) that can receive data from @@ -844,7 +845,7 @@ Some of the common ones are as follows.
    -The last two transformations are worth highlighting again. +A few of these transformations are worth discussing in more detail. #### UpdateStateByKey Operation {:.no_toc} @@ -876,6 +877,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Scala code, take a look at the example +[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache +/spark/examples/streaming/StatefulNetworkWordCount.scala). +
    @@ -897,6 +904,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Java code, take a look at the example +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +/JavaStatefulNetworkWordCount.java). +
    @@ -914,14 +927,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi runningCounts = pairs.updateStateByKey(updateFunction) {% endhighlight %} -
    -
    - The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example +Python code, take a look at the example [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +
    +
    + Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is discussed in detail in the [checkpointing](#checkpointing) section. @@ -983,7 +996,7 @@ In fact, you can also use [machine learning](mllib-guide.html) and #### Window Operations {:.no_toc} -Finally, Spark Streaming also provides *windowed computations*, which allow you to apply +Spark Streaming also provides *windowed computations*, which allow you to apply transformations over a sliding window of data. This following figure illustrates this sliding window. @@ -1027,7 +1040,7 @@ val windowedWordCounts = pairs.reduceByKeyAndWindow((a:Int,b:Int) => (a + b), Se {% highlight java %} // Reduce function adding two integers, defined separately for clarity Function2 reduceFunc = new Function2() { - @Override public Integer call(Integer i1, Integer i2) throws Exception { + @Override public Integer call(Integer i1, Integer i2) { return i1 + i2; } }; @@ -1106,6 +1119,100 @@ said two parameters - windowLength and slideInterval. +#### Join Operations +{:.no_toc} +Finally, its worth highlighting how easily you can perform different kinds of joins in Spark Streaming. + + +##### Stream-stream joins +{:.no_toc} +Streams can be very easily joined with other streams. + +
    +
    +{% highlight scala %} +val stream1: DStream[String, String] = ... +val stream2: DStream[String, String] = ... +val joinedStream = stream1.join(stream2) +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairDStream stream1 = ... +JavaPairDStream stream2 = ... +JavaPairDStream joinedStream = stream1.join(stream2); +{% endhighlight %} +
    +
    +{% highlight python %} +stream1 = ... +stream2 = ... +joinedStream = stream1.join(stream2) +{% endhighlight %} +
    +
    +Here, in each batch interval, the RDD generated by `stream1` will be joined with the RDD generated by `stream2`. You can also do `leftOuterJoin`, `rightOuterJoin`, `fullOuterJoin`. Furthermore, it is often very useful to do joins over windows of the streams. That is pretty easy as well. + +
    +
    +{% highlight scala %} +val windowedStream1 = stream1.window(Seconds(20)) +val windowedStream2 = stream2.window(Minutes(1)) +val joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairDStream windowedStream1 = stream1.window(Durations.seconds(20)); +JavaPairDStream windowedStream2 = stream2.window(Durations.minutes(1)); +JavaPairDStream joinedStream = windowedStream1.join(windowedStream2); +{% endhighlight %} +
    +
    +{% highlight python %} +windowedStream1 = stream1.window(20) +windowedStream2 = stream2.window(60) +joinedStream = windowedStream1.join(windowedStream2) +{% endhighlight %} +
    +
    + +##### Stream-dataset joins +{:.no_toc} +This has already been shown earlier while explain `DStream.transform` operation. Here is yet another example of joining a windowed stream with a dataset. + +
    +
    +{% highlight scala %} +val dataset: RDD[String, String] = ... +val windowedStream = stream.window(Seconds(20))... +val joinedStream = windowedStream.transform { rdd => rdd.join(dataset) } +{% endhighlight %} +
    +
    +{% highlight java %} +JavaPairRDD dataset = ... +JavaPairDStream windowedStream = stream.window(Durations.seconds(20)); +JavaPairDStream joinedStream = windowedStream.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) { + return rdd.join(dataset); + } + } +); +{% endhighlight %} +
    +
    +{% highlight python %} +dataset = ... # some RDD +windowedStream = stream.window(20) +joinedStream = windowedStream.transform(lambda rdd: rdd.join(dataset)) +{% endhighlight %} +
    +
    + +In fact, you can also dynamically change the dataset you want to join against. The function provided to `transform` is evaluated every batch interval and therefore will use the current dataset that `dataset` reference points to. The complete list of DStream transformations is available in the API documentation. For the Scala API, see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) @@ -1313,6 +1420,178 @@ Note that the connections in the pool should be lazily created on demand and tim *** +## DataFrame and SQL Operations +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. + +
    +
    +{% highlight scala %} + +/** Lazily instantiated singleton instance of SQLContext */ +object SQLContextSingleton { + @transient private var instance: SQLContext = null + + // Instantiate SQLContext on demand + def getInstance(sparkContext: SparkContext): SQLContext = synchronized { + if (instance == null) { + instance = new SQLContext(sparkContext) + } + instance + } +} + +... + +/** Case class for converting RDD to DataFrame */ +case class Row(word: String) + +... + +/** DataFrame operations inside your streaming program */ + +val words: DStream[String] = ... + +words.foreachRDD { rdd => + + // Get the singleton instance of SQLContext + val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to RDD[case class] to DataFrame + val wordsDataFrame = rdd.map(w => Row(w)).toDF() + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on DataFrame using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala). +
    +
    +{% highlight java %} + +/** Lazily instantiated singleton instance of SQLContext */ +class JavaSQLContextSingleton { + static private transient SQLContext instance = null; + static public SQLContext getInstance(SparkContext sparkContext) { + if (instance == null) { + instance = new SQLContext(sparkContext); + } + return instance; + } +} + +... + +/** Java Bean class for converting RDD to DataFrame */ +public class JavaRow implements java.io.Serializable { + private String word; + + public String getWord() { + return word; + } + + public void setWord(String word) { + this.word = word; + } +} + +... + +/** DataFrame operations inside your streaming program */ + +JavaDStream words = ... + +words.foreachRDD( + new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Convert RDD[String] to RDD[case class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRow call(String word) { + JavaRow record = new JavaRow(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRow.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + wordCountsDataFrame.show(); + return null; + } + } +); +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java). +
    +
    +{% highlight python %} + +# Lazily instantiated global instance of SQLContext +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + +... + +# DataFrame operations inside your streaming program + +words = ... # DStream of strings + +def process(time, rdd): + print "========= %s =========" % str(time) + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + +words.foreachRDD(process) +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py). + +
    +
    + +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). + +See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. + +*** + +## MLlib Operations +You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. + +*** + ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream will automatically persist every RDD of that DStream in @@ -1566,9 +1845,8 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. - -- *[Experimental in Spark 1.2] Configuring write ahead logs* - In Spark 1.2, - we have introduced a new experimental feature of write ahead logs for achieving strong +- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, + we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the @@ -1654,7 +1932,7 @@ improve the performance of you application. At a high level, you need to conside 2. Setting the right batch size such that the batches of data can be processed as fast as they are received (that is, data processing keeps up with the data ingestion). -## Reducing the Processing Time of each Batch +## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones. @@ -1726,16 +2004,15 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overhead of data serialization can be significant, especially when sub-second batch sizes are - to be achieved. There are two aspects to it. +The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized. -* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data - serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default - RDDs are persisted as serialized byte arrays to minimize pauses related to GC. +* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Serialization of input data**: To ingest external data into Spark, data received as bytes - (say, from the network) needs to deserialized from bytes and re-serialized into Spark's - serialization format. Hence, the deserialization overhead of input data may be a bottleneck. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads. + +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). + +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. ### Task Launching Overheads {:.no_toc} @@ -1755,7 +2032,7 @@ thus allowing sub-second batch size to be viable. *** -## Setting the Right Batch Size +## Setting the Right Batch Interval For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by @@ -1787,40 +2064,40 @@ temporary data rate increases maybe fine as long as the delay reduces back to a ## Memory Tuning Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail -in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, -we highlight a few customizations that are strongly recommended to minimize GC related pauses -in Spark Streaming applications and achieving more consistent batch processing times. - -* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams -serializes the data in memory (that is, -[StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for -DStream compared to -[StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$) for RDDs). -Even though keeping the data serialized incurs higher serialization/deserialization overheads, -it significantly reduces GC pauses. - -* **Clearing persistent RDDs**: By default, all persistent RDDs generated by Spark Streaming will - be cleared from memory based on Spark's built-in policy (LRU). If `spark.cleaner.ttl` is set, - then persistent RDDs that are older than that value are periodically cleared. As mentioned - [earlier](#operation), this needs to be careful set based on operations used in the Spark - Streaming program. However, a smarter unpersisting of RDDs can be enabled by setting the - [configuration property](configuration.html#spark-properties) `spark.streaming.unpersist` to - `true`. This makes the system to figure out which RDDs are not necessary to be kept around and - unpersists them. This is likely to reduce - the RDD memory usage of Spark, potentially improving GC behavior as well. - -* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further -minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the +in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. + +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low. + +In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. + +Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. + +There are a few parameters that can help you tune the memory usage and GC overheads. + +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. + +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data. +Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. + +* **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more -consistent batch processing times. +consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`). + +* **Other tips**: To further reduce GC overheads, here are some more tips to try. + - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). + - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. + *************************************************************************************************** *************************************************************************************************** # Fault-tolerance Semantics In this section, we will discuss the behavior of Spark Streaming applications in the event -of node failures. To understand this, let us remember the basic fault-tolerance semantics of -Spark's RDDs. +of failures. + +## Background +{:.no_toc} +To understand the semantics provided by Spark Streaming, let us remember the basic fault-tolerance semantics of Spark's RDDs. 1. An RDD is an immutable, deterministically re-computable, distributed dataset. Each RDD remembers the lineage of deterministic operations that were used on a fault-tolerant input @@ -1854,13 +2131,43 @@ Furthermore, there are two kinds of failures that we should be concerned about: With this basic knowledge, let us understand the fault-tolerance semantics of Spark Streaming. -## Semantics with files as input source +## Definitions +{:.no_toc} +The semantics of streaming systems are often captured in terms of how many times each record can be processed by the system. There are three types of guarantees that a system can provide under all possible operating conditions (despite failures, etc.) + +1. *At most once*: Each record will be either processed once or not processed at all. +2. *At least once*: Each record will be processed one or more times. This is stronger than *at-most once* as it ensure that no data will be lost. But there may be duplicates. +3. *Exactly once*: Each record will be processed exactly once - no data will be lost and no data will be processed multiple times. This is obviously the strongest guarantee of the three. + +## Basic Semantics +{:.no_toc} +In any stream processing system, broadly speaking, there are three steps in processing the data. + +1. *Receiving the data*: The data is received from sources using Receivers or otherwise. + +1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations. + +1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. + +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. + +1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection. + +1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. + +1. *Pushing out the data*: Output operations by default ensure _at-least once_ semantics because it depends on the type of output operation (idempotent, or not) and the semantics of the downstream system (supports transactions or not). But users can implement their own transaction mechanisms to achieve _exactly-once_ semantics. This is discussed in more details later in the section. + +## Semantics of Received Data +{:.no_toc} +Different input sources provide different guarantees, ranging from _at-least once_ to _exactly once_. Read for more details. + +### With Files {:.no_toc} If all of the input data is already present in a fault-tolerant files system like HDFS, Spark Streaming can always recover from any failure and process all the data. This gives *exactly-once* semantics, that all the data will be processed exactly once no matter what fails. -## Semantics with input sources based on receivers +### With Receiver-based Sources {:.no_toc} For input sources based on receivers, the fault-tolerance semantics depend on both the failure scenario and the type of receiver. @@ -1879,10 +2186,9 @@ receivers, data received but not replicated can get lost. If the driver node fai then besides these losses, all the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. -To avoid this loss of past received data, Spark 1.2 introduces an experimental feature of _write +To avoid this loss of past received data, Spark 1.2 introduced _write ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss and -exactly-once semantics. +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee. The following table summarizes the semantics under failures: @@ -1894,23 +2200,30 @@ The following table summarizes the semantics under failures: - Spark 1.1 or earlier, or
    - Spark 1.2 without write ahead log + Spark 1.1 or earlier, OR
    + Spark 1.2 or later without write ahead logs Buffered data lost with unreliable receivers
    - Zero data loss with reliable receivers and files
    + Zero data loss with reliable receivers
    + At-least once semantics Buffered data lost with unreliable receivers
    Past data lost with all receivers
    - Zero data loss with files - + Undefined semantics + - Spark 1.2 with write ahead log - Zero data loss with reliable receivers and files - Zero data loss with reliable receivers and files + Spark 1.2 or later with write ahead logs + + Zero data loss with reliable receivers
    + At-least once semantics + + + Zero data loss with reliable receivers and files
    + At-least once semantics + @@ -1919,17 +2232,24 @@ The following table summarizes the semantics under failures: +### With Kafka Direct API +{:.no_toc} +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). + ## Semantics of output operations {:.no_toc} -Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation - always leads to the same result. As a result, all DStream transformations are guaranteed to have - _exactly-once_ semantics. That is, the final transformed result will be same even if there were - was a worker node failure. However, output operations (like `foreachRDD`) have _at-least once_ - semantics, that is, the transformed data may get written to an external entity more than once in - the event of a worker failure. While this is acceptable for saving to HDFS using the - `saveAs***Files` operations (as the file will simply get over-written by the same data), - additional transactions-like mechanisms may be necessary to achieve exactly-once semantics - for output operations. +Output operations (like `foreachRDD`) have _at-least once_ semantics, that is, +the transformed data may get written to an external entity more than once in +the event of a worker failure. While this is acceptable for saving to file systems using the +`saveAs***Files` operations (as the file will simply get overwritten with the same data), +additional effort may be necessary to achieve exactly-once semantics. There are two approaches. + +- *Idempotent updates*: Multiple attempts always write the same data. For example, `saveAs***Files` always writes the same data to the generated files. + +- *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. + + - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update. *************************************************************************************************** @@ -1987,7 +2307,11 @@ package and renamed for better clarity. *************************************************************************************************** # Where to Go from Here - +* Additional guides + - [Kafka Integration Guide](streaming-kafka-integration.html) + - [Flume Integration Guide](streaming-flume-integration.html) + - [Kinesis Integration Guide](streaming-kinesis-integration.html) + - [Custom Receiver Guide](streaming-custom-receivers.html) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and @@ -2009,8 +2333,8 @@ package and renamed for better clarity. [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) - Python docs - * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) - * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) + * [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfaf..3ecbf2308cd4 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. @@ -133,10 +133,10 @@ The master URL passed to Spark can be in one of the following formats: Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... yarn-client Connect to a YARN cluster in -client mode. The cluster location will be found based on the HADOOP_CONF_DIR variable. +client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. yarn-cluster Connect to a YARN cluster in -cluster mode. The cluster location will be found based on HADOOP_CONF_DIR. +cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. @@ -174,6 +174,11 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. + For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries to executors. diff --git a/docs/tuning.md b/docs/tuning.md index efaac9d3d405..cbd227868b24 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -1,6 +1,8 @@ --- layout: global -title: Tuning Spark +displayTitle: Tuning Spark +title: Tuning +description: Tuning and performance optimization guide for Spark SPARK_VERSION_SHORT --- * This will become a table of contents (this text will be scraped). diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh index 740c267fd986..4f3e8da809f7 100644 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -25,10 +25,10 @@ export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" export MODULES="{{modules}}" export SPARK_VERSION="{{spark_version}}" -export SHARK_VERSION="{{shark_version}}" +export TACHYON_VERSION="{{tachyon_version}}" export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" export SWAP_MB="{{swap}}" export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" export SPARK_MASTER_OPTS="{{spark_master_opts}}" export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" \ No newline at end of file +export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 3abd3f396f60..26e7d2265569 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -20,6 +20,6 @@ # Preserve the user's CWD so that relative paths are passed correctly to #+ the underlying Python script. -SPARK_EC2_DIR="$(dirname $0)" +SPARK_EC2_DIR="$(dirname "$0")" python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index abab209a05ba..87c081827971 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,26 +19,38 @@ # limitations under the License. # -from __future__ import with_statement +from __future__ import with_statement, print_function import hashlib +import itertools import logging import os +import os.path import pipes import random import shutil import string +from stat import S_IRUSR import subprocess import sys import tarfile import tempfile +import textwrap import time -import urllib2 import warnings from datetime import datetime from optparse import OptionParser from sys import stderr +if sys.version < "3": + from urllib2 import urlopen, Request, HTTPError +else: + from urllib.request import urlopen, Request + from urllib.error import HTTPError + +SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) + VALID_SPARK_VERSIONS = set([ "0.7.3", "0.8.0", @@ -52,45 +64,81 @@ "1.1.0", "1.1.1", "1.2.0", + "1.2.1", ]) -DEFAULT_SPARK_VERSION = "1.2.0" +SPARK_TACHYON_MAP = { + "1.0.0": "0.4.1", + "1.0.1": "0.4.1", + "1.0.2": "0.4.1", + "1.1.0": "0.5.0", + "1.1.1": "0.5.0", + "1.2.0": "0.5.0", + "1.2.1": "0.5.0", +} + +DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) -MESOS_SPARK_EC2_BRANCH = "branch-1.3" - -# A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) - - -def setup_boto(): - # Download Boto if it's not already present in the SPARK_EC2_DIR/lib folder: - version = "boto-2.34.0" - md5 = "5556223d2d0cc4d06dd4829e671dcecd" - url = "https://pypi.python.org/packages/source/b/boto/%s.tar.gz" % version - lib_dir = os.path.join(SPARK_EC2_DIR, "lib") - if not os.path.exists(lib_dir): - os.mkdir(lib_dir) - boto_lib_dir = os.path.join(lib_dir, version) - if not os.path.isdir(boto_lib_dir): - tgz_file_path = os.path.join(lib_dir, "%s.tar.gz" % version) - print "Downloading Boto from PyPi" - download_stream = urllib2.urlopen(url) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path) as tar: - if hashlib.md5(tar.read()).hexdigest() != md5: - print >> stderr, "ERROR: Got wrong md5sum for Boto" - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=lib_dir) - tar.close() - os.remove(tgz_file_path) - print "Finished downloading Boto" - sys.path.insert(0, boto_lib_dir) +# Default location to get the spark-ec2 scripts (and ami-list) from +DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" + + +def setup_external_libs(libs): + """ + Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. + """ + PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" + SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") + + if not os.path.exists(SPARK_EC2_LIB_DIR): + print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( + path=SPARK_EC2_LIB_DIR + )) + print("This should be a one-time operation.") + os.mkdir(SPARK_EC2_LIB_DIR) + + for lib in libs: + versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) + lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) + + if not os.path.isdir(lib_dir): + tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") + print(" - Downloading {lib}...".format(lib=lib["name"])) + download_stream = urlopen( + "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( + prefix=PYPI_URL_PREFIX, + first_letter=lib["name"][:1], + lib_name=lib["name"], + lib_version=lib["version"] + ) + ) + with open(tgz_file_path, "wb") as tgz_file: + tgz_file.write(download_stream.read()) + with open(tgz_file_path) as tar: + if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: + print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) + sys.exit(1) + tar = tarfile.open(tgz_file_path) + tar.extractall(path=SPARK_EC2_LIB_DIR) + tar.close() + os.remove(tgz_file_path) + print(" - Finished downloading {lib}.".format(lib=lib["name"])) + sys.path.insert(1, lib_dir) + + +# Only PyPI libraries are supported. +external_libs = [ + { + "name": "boto", + "version": "2.34.0", + "md5": "5556223d2d0cc4d06dd4829e671dcecd" + } +] + +setup_external_libs(external_libs) -setup_boto() import boto from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 @@ -103,12 +151,11 @@ class UsageError(Exception): # Configure and parse our command-line arguments def parse_args(): parser = OptionParser( - usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", - add_help_option=False) - parser.add_option( - "-h", "--help", action="help", - help="Show this help message and exit") + prog="spark-ec2", + version="%prog {v}".format(v=SPARK_EC2_VERSION), + usage="%prog [options] \n\n" + + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") + parser.add_option( "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") @@ -130,13 +177,15 @@ def parse_args(): help="Master instance type (leave empty for same as instance-type)") parser.add_option( "-r", "--region", default="us-east-1", - help="EC2 region zone to launch instances in") + help="EC2 region used to launch instances in, or to find them in (default: %default)") parser.add_option( "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + "slaves across multiple (an additional $0.01/Gb for bandwidth" + "between zones applies) (default: a single zone chosen at random)") - parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") + parser.add_option( + "-a", "--ami", + help="Amazon Machine Image ID to use") parser.add_option( "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") @@ -144,6 +193,23 @@ def parse_args(): "--spark-git-repo", default=DEFAULT_SPARK_GITHUB_REPO, help="Github repo from which to checkout supplied commit hash (default: %default)") + parser.add_option( + "--spark-ec2-git-repo", + default=DEFAULT_SPARK_EC2_GITHUB_REPO, + help="Github repo from which to checkout spark-ec2 (default: %default)") + parser.add_option( + "--spark-ec2-git-branch", + default=DEFAULT_SPARK_EC2_BRANCH, + help="Github repo branch of spark-ec2 to use (default: %default)") + parser.add_option( + "--deploy-root-dir", + default=None, + help="A directory to copy into / on the first master. " + + "Must be absolute. Note that a trailing slash is handled as per rsync: " + + "If you omit it, the last directory of the --deploy-root-dir path will be created " + + "in / before copying its contents. If you append the trailing slash, " + + "the directory is not created and its contents are copied directly into /. " + + "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", help="Major version of Hadoop (default: %default)") @@ -168,10 +234,11 @@ def parse_args(): "Only possible on EBS-backed AMIs. " + "EBS volumes are only attached if --ebs-vol-size > 0." + "Only support up to 8 EBS volumes.") - parser.add_option("--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") + parser.add_option( + "--placement-group", type="string", default=None, + help="Which placement group to try and launch " + + "instances into. Assumes placement group is already " + + "created.") parser.add_option( "--swap", metavar="SWAP", type="int", default=1024, help="Swap space to set up per node, in MB (default: %default)") @@ -204,7 +271,7 @@ def parse_args(): "(e.g -Dspark.worker.timeout=180)") parser.add_option( "--user-data", type="string", default="", - help="Path to a user-data file (most AMI's interpret this as an initialization script)") + help="Path to a user-data file (most AMIs interpret this as an initialization script)") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", help="Address to authorize on created security groups (default: %default)") @@ -215,9 +282,15 @@ def parse_args(): "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") parser.add_option( - "--subnet-id", default=None, help="VPC subnet to launch instances in") + "--subnet-id", default=None, + help="VPC subnet to launch instances in") parser.add_option( - "--vpc-id", default=None, help="VPC to launch instances in") + "--vpc-id", default=None, + help="VPC to launch instances in") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -231,12 +304,12 @@ def parse_args(): if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) sys.exit(1) if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) sys.exit(1) return (opts, action, cluster_name) @@ -248,7 +321,7 @@ def get_or_make_group(conn, name, vpc_id): if len(group) > 0: return group[0] else: - print "Creating security group " + name + print("Creating security group " + name) return conn.create_security_group(name, "Spark EC2 group", vpc_id) @@ -256,86 +329,90 @@ def get_validate_spark_version(version, repo): if "." in version: version = version.replace("v", "") if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) + print("Don't know about Spark version: {v}".format(v=version), file=stderr) sys.exit(1) return version else: github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = urllib2.Request(github_commit_url) + request = Request(github_commit_url) request.get_method = lambda: 'HEAD' try: - response = urllib2.urlopen(request) - except urllib2.HTTPError, e: - print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url) - print >> stderr, "Received HTTP response code of {code}.".format(code=e.code) + response = urlopen(request) + except HTTPError as e: + print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), + file=stderr) + print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) sys.exit(1) return version -# Check whether a given EC2 instance object is in a state we consider active, -# i.e. not terminating or terminated. We count both stopping and stopped as -# active since we can restart stopped clusters. -def is_active(instance): - return (instance.state in ['pending', 'running', 'stopping', 'stopped']) - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ # Last Updated: 2014-06-20 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. +EC2_INSTANCE_TYPES = { + "c1.medium": "pvm", + "c1.xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", + "cc1.4xlarge": "hvm", + "cc2.8xlarge": "hvm", + "cg1.4xlarge": "hvm", + "cr1.8xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", + "i2.2xlarge": "hvm", + "i2.4xlarge": "hvm", + "i2.8xlarge": "hvm", + "i2.xlarge": "hvm", + "m1.large": "pvm", + "m1.medium": "pvm", + "m1.small": "pvm", + "m1.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m2.xlarge": "pvm", + "m3.2xlarge": "hvm", + "m3.large": "hvm", + "m3.medium": "hvm", + "m3.xlarge": "hvm", + "r3.2xlarge": "hvm", + "r3.4xlarge": "hvm", + "r3.8xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "t1.micro": "pvm", + "t2.medium": "hvm", + "t2.micro": "hvm", + "t2.small": "hvm", +} + + +def get_tachyon_version(spark_version): + return SPARK_TACHYON_MAP.get(spark_version, "") + + +# Attempt to resolve an appropriate AMI given the architecture and region of the request. def get_spark_ami(opts): - instance_types = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "i2.xlarge": "hvm", - "m1.large": "pvm", - "m1.medium": "pvm", - "m1.small": "pvm", - "m1.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m2.xlarge": "pvm", - "m3.2xlarge": "hvm", - "m3.large": "hvm", - "m3.medium": "hvm", - "m3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "t1.micro": "pvm", - "t2.medium": "hvm", - "t2.micro": "hvm", - "t2.small": "hvm", - } - if opts.instance_type in instance_types: - instance_type = instance_types[opts.instance_type] + if opts.instance_type in EC2_INSTANCE_TYPES: + instance_type = EC2_INSTANCE_TYPES[opts.instance_type] else: instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % opts.instance_type + print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) - ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) + # URL prefix from which to fetch AMI information + ami_prefix = "{r}/{b}/ami-list".format( + r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), + b=opts.spark_ec2_git_branch) + + ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI: " + ami + ami = urlopen(ami_path).read().strip() + print("Spark AMI: " + ami) except: - print >> stderr, "Could not resolve AMI at: " + ami_path + print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) return ami @@ -347,10 +424,11 @@ def get_spark_ami(opts): # Fails if there already instances running in the cluster's groups. def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." + print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) sys.exit(1) + if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." + print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) sys.exit(1) user_data_content = None @@ -358,7 +436,7 @@ def launch_cluster(conn, opts, cluster_name): with open(opts.user_data) as user_data_file: user_data_content = user_data_file.read() - print "Setting up security groups..." + print("Setting up security groups...") master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) authorized_address = opts.authorized_address @@ -387,6 +465,13 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # HDFS NFS gateway requires 111,2049,4242 for tcp & udp + master_group.authorize('tcp', 111, 111, authorized_address) + master_group.authorize('udp', 111, 111, authorized_address) + master_group.authorize('tcp', 2049, 2049, authorized_address) + master_group.authorize('udp', 2049, 2049, authorized_address) + master_group.authorize('tcp', 4242, 4242, authorized_address) + master_group.authorize('udp', 4242, 4242, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -417,8 +502,8 @@ def launch_cluster(conn, opts, cluster_name): existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print("ERROR: There are already instances running in group %s or %s" % + (master_group.name, slave_group.name), file=stderr) sys.exit(1) # Figure out Spark AMI @@ -431,12 +516,12 @@ def launch_cluster(conn, opts, cluster_name): additional_group_ids = [sg.id for sg in conn.get_all_security_groups() if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." + print("Launching instances...") try: image = conn.get_all_images(image_ids=[opts.ami])[0] except: - print >> stderr, "Could not find AMI " + opts.ami + print("Could not find AMI " + opts.ami, file=stderr) sys.exit(1) # Create block device mapping so that we can add EBS volumes if asked to. @@ -462,8 +547,8 @@ def launch_cluster(conn, opts, cluster_name): # Launch slaves if opts.spot_price is not None: # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) + print("Requesting %d slaves as spot instances with price $%.3f" % + (opts.slaves, opts.spot_price)) zones = get_zones(conn, opts) num_zones = len(zones) i = 0 @@ -486,7 +571,7 @@ def launch_cluster(conn, opts, cluster_name): my_req_ids += [req.id for req in slave_reqs] i += 1 - print "Waiting for spot instances to be granted..." + print("Waiting for spot instances to be granted...") try: while True: time.sleep(10) @@ -499,24 +584,24 @@ def launch_cluster(conn, opts, cluster_name): if i in id_to_req and id_to_req[i].state == "active": active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves + print("All %d slaves granted" % opts.slaves) reservations = conn.get_all_reservations(active_instance_ids) slave_nodes = [] for r in reservations: slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print("%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves)) except: - print "Canceling spot instance requests" + print("Canceling spot instance requests") conn.cancel_spot_instance_requests(my_req_ids) # Log a warning if any of these requests actually launched instances: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) running = len(master_nodes) + len(slave_nodes) if running: - print >> stderr, ("WARNING: %d instances are still running" % running) + print(("WARNING: %d instances are still running" % running), file=stderr) sys.exit(0) else: # Launch non-spot instances @@ -538,13 +623,16 @@ def launch_cluster(conn, opts, cluster_name): placement_group=opts.placement_group, user_data=user_data_content) slave_nodes += slave_res.instances - print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone, - zone, slave_res.id) + print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( + s=num_slaves_this_zone, + plural_s=('' if num_slaves_this_zone == 1 else 's'), + z=zone, + r=slave_res.id)) i += 1 # Launch or resume masters if existing_masters: - print "Starting master..." + print("Starting master...") for inst in existing_masters: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -567,8 +655,11 @@ def launch_cluster(conn, opts, cluster_name): user_data=user_data_content) master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + print("Launched master in %s, regid = %s" % (zone, master_res.id)) + # This wait time corresponds to SPARK-4983 + print("Waiting for AWS to propagate instance metadata...") + time.sleep(5) # Give the instances descriptive names for master in master_nodes: master.add_tag( @@ -583,43 +674,50 @@ def launch_cluster(conn, opts, cluster_name): return (master_nodes, slave_nodes) -# Get the EC2 instances in an existing cluster if available. -# Returns a tuple of lists of EC2 instance objects for the masters and slaves +def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): + """ + Get the EC2 instances in an existing cluster if available. + Returns a tuple of lists of EC2 instance objects for the masters and slaves. + """ + print("Searching for existing cluster {c} in region {r}...".format( + c=cluster_name, r=opts.region)) + def get_instances(group_names): + """ + Get all non-terminated instances that belong to any of the provided security groups. -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - print "Searching for existing cluster " + cluster_name + "..." - reservations = conn.get_all_reservations() - master_nodes = [] - slave_nodes = [] - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for inst in active: - group_names = [g.name for g in inst.groups] - if (cluster_name + "-master") in group_names: - master_nodes.append(inst) - elif (cluster_name + "-slaves") in group_names: - slave_nodes.append(inst) - if any((master_nodes, slave_nodes)): - print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) - if master_nodes != [] or not die_on_error: - return (master_nodes, slave_nodes) - else: - if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" - else: - print >> sys.stderr, "ERROR: Could not find any existing cluster" + EC2 reservation filters and instance states are documented here: + http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options + """ + reservations = conn.get_all_reservations( + filters={"instance.group-name": group_names}) + instances = itertools.chain.from_iterable(r.instances for r in reservations) + return [i for i in instances if i.state not in ["shutting-down", "terminated"]] + + master_instances = get_instances([cluster_name + "-master"]) + slave_instances = get_instances([cluster_name + "-slaves"]) + + if any((master_instances, slave_instances)): + print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( + m=len(master_instances), + plural_m=('' if len(master_instances) == 1 else 's'), + s=len(slave_instances), + plural_s=('' if len(slave_instances) == 1 else 's'))) + + if not master_instances and die_on_error: + print("ERROR: Could not find a master for cluster {c} in region {r}.".format( + c=cluster_name, r=opts.region), file=sys.stderr) sys.exit(1) + return (master_instances, slave_instances) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. - - def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: - print "Generating cluster's SSH key on master..." + print("Generating cluster's SSH key on master...") key_setup = """ [ -f ~/.ssh/id_rsa ] || (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && @@ -627,10 +725,11 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): """ ssh(master, opts, key_setup) dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." + print("Transferring cluster's SSH key to slaves...") for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print(slave_address) + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone', 'tachyon'] @@ -643,15 +742,18 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten + print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) ssh( host=master, opts=opts, command="rm -rf spark-ec2" + " && " - + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, + b=opts.spark_ec2_git_branch) ) - print "Deploying files to master..." + print("Deploying files to master...") deploy_files( conn=conn, root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", @@ -661,35 +763,54 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): modules=modules ) - print "Running setup on master..." + if opts.deploy_root_dir is not None: + print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) + deploy_user_files( + root_dir=opts.deploy_root_dir, + opts=opts, + master_nodes=master_nodes + ) + + print("Running setup on master...") setup_spark_cluster(master, opts) - print "Done!" + print("Done!") def setup_spark_cluster(master, opts): ssh(master, opts, "chmod u+x spark-ec2/setup.sh") ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master + print("Spark standalone cluster started at http://%s:8080" % master) if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master + print("Ganglia started at http://%s:5080/ganglia" % master) -def is_ssh_available(host, opts): +def is_ssh_available(host, opts, print_ssh_output=True): """ Check if SSH is available on a host. """ - try: - with open(os.devnull, 'w') as devnull: - ret = subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=devnull, - stderr=devnull - ) - return ret == 0 - except subprocess.CalledProcessError as e: - return False + s = subprocess.Popen( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order + ) + cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout + + if s.returncode != 0 and print_ssh_output: + # extra leading newline is for spacing in wait_for_cluster_state() + print(textwrap.dedent("""\n + Warning: SSH connection error. (This could be temporary.) + Host: {h} + SSH return code: {r} + SSH output: {o} + """).format( + h=host, + r=s.returncode, + o=cmd_output.strip() + )) + + return s.returncode == 0 def is_cluster_ssh_available(cluster_instances, opts): @@ -697,7 +818,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -747,10 +869,10 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): sys.stdout.write("\n") end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + print("Cluster is now in '{s}' state. Waited {t} seconds.".format( s=cluster_state, t=(end_time - start_time).seconds - ) + )) # Get number of local disks available for a given EC2 instance type. @@ -798,8 +920,8 @@ def get_num_disks(instance_type): if instance_type in disks_by_instance: return disks_by_instance[instance_type] else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) + print("WARNING: Don't know number of disks on instance type %s; assuming 1" + % instance_type, file=stderr) return 1 @@ -811,7 +933,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -828,14 +950,20 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): if "." in opts.spark_version: # Pre-built Spark deploy spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) + tachyon_v = get_tachyon_version(spark_v) else: # Spark-only custom deploy spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) + tachyon_v = "" + print("Deploying Spark via git hash; Tachyon won't be set up") + modules = filter(lambda x: x != "tachyon", modules) + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), + "master_list": '\n'.join(master_addresses), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -843,6 +971,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "swap": str(opts.swap), "modules": '\n'.join(modules), "spark_version": spark_v, + "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, "spark_worker_instances": "%d" % opts.worker_instances, "spark_master_opts": opts.master_opts @@ -887,6 +1016,23 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): shutil.rmtree(tmp_dir) +# Deploy a given local directory to a cluster, WITHOUT parameter substitution. +# Note that unlike deploy_files, this works for binary files. +# Also, it is up to the user to add (or not) the trailing slash in root_dir. +# Files are only deployed to the first master instance in the cluster. +# +# root_dir should be an absolute path. +def deploy_user_files(root_dir, opts, master_nodes): + active_master = get_dns_name(master_nodes[0], opts.private_ips) + command = [ + 'rsync', '-rv', + '-e', stringify_command(ssh_command(opts)), + "%s" % root_dir, + "%s@%s:/" % (opts.user, active_master) + ] + subprocess.check_call(command) + + def stringify_command(parts): if isinstance(parts, str): return parts @@ -896,6 +1042,7 @@ def stringify_command(parts): def ssh_args(opts): parts = ['-o', 'StrictHostKeyChecking=no'] + parts += ['-o', 'UserKnownHostsFile=/dev/null'] if opts.identity_file is not None: parts += ['-i', opts.identity_file] return parts @@ -924,8 +1071,8 @@ def ssh(host, opts, command): "--key-pair parameters and try again.".format(host)) else: raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) + print("Error executing remote command, retrying after 30 seconds: {0}".format(e), + file=stderr) time.sleep(30) tries = tries + 1 @@ -964,8 +1111,8 @@ def ssh_write(host, opts, command, arguments): elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) + print("Error {0} while executing remote command, retrying after 30 seconds". + format(status), file=stderr) time.sleep(30) tries = tries + 1 @@ -987,6 +1134,20 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1003,14 +1164,67 @@ def real_main(): DeprecationWarning ) + if opts.identity_file is not None: + if not os.path.exists(opts.identity_file): + print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), + file=stderr) + sys.exit(1) + + file_mode = os.stat(opts.identity_file).st_mode + if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': + print("ERROR: The identity file must be accessible only by you.", file=stderr) + print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), + file=stderr) + sys.exit(1) + + if opts.instance_type not in EC2_INSTANCE_TYPES: + print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type), file=stderr) + + if opts.master_instance_type != "": + if opts.master_instance_type not in EC2_INSTANCE_TYPES: + print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type), file=stderr) + # Since we try instance types even if we can't resolve them, we check if they resolve first + # and, if they do, see if they resolve to the same virtualization type. + if opts.instance_type in EC2_INSTANCE_TYPES and \ + opts.master_instance_type in EC2_INSTANCE_TYPES: + if EC2_INSTANCE_TYPES[opts.instance_type] != \ + EC2_INSTANCE_TYPES[opts.master_instance_type]: + print("Error: spark-ec2 currently does not support having a master and slaves " + "with different AMI virtualization types.", file=stderr) + print("master instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) + print("slave instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) + sys.exit(1) + if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" + print("ebs-vol-num cannot be greater than 8", file=stderr) + sys.exit(1) + + # Prevent breaking ami_prefix (/, .git and startswith checks) + # Prevent forks with non spark-ec2 names for now. + if opts.spark_ec2_git_repo.endswith("/") or \ + opts.spark_ec2_git_repo.endswith(".git") or \ + not opts.spark_ec2_git_repo.startswith("https://github.com") or \ + not opts.spark_ec2_git_repo.endswith("spark-ec2"): + print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " + "Furthermore, we currently only support forks named spark-ec2.", file=stderr) + sys.exit(1) + + if not (opts.deploy_root_dir is None or + (os.path.isabs(opts.deploy_root_dir) and + os.path.isdir(opts.deploy_root_dir) and + os.path.exists(opts.deploy_root_dir))): + print("--deploy-root-dir must be an absolute path to a directory that exists " + "on the local file system", file=stderr) sys.exit(1) try: conn = ec2.connect_to_region(opts.region) except Exception as e: - print >> stderr, (e) + print((e), file=stderr) sys.exit(1) # Select an AZ at random if it was not specified. @@ -1019,7 +1233,7 @@ def real_main(): if action == "launch": if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" + print("ERROR: You have to start at least 1 slave", file=sys.stderr) sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -1034,26 +1248,27 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": - print "Are you sure you want to destroy the cluster %s?" % cluster_name - print "The following instances will be terminated:" (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name - msg = "ALL DATA ON ALL NODES WILL BE LOST!!\nDestroy cluster %s (y/N): " % cluster_name + if any(master_nodes + slave_nodes): + print("The following instances will be terminated:") + for inst in master_nodes + slave_nodes: + print("> %s" % get_dns_name(inst, opts.private_ips)) + print("ALL DATA ON ALL NODES WILL BE LOST!!") + + msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) response = raw_input(msg) if response == "y": - print "Terminating master..." + print("Terminating master...") for inst in master_nodes: inst.terminate() - print "Terminating slaves..." + print("Terminating slaves...") for inst in slave_nodes: inst.terminate() # Delete security groups as well if opts.delete_groups: - print "Deleting security groups (this will take some time)..." group_names = [cluster_name + "-master", cluster_name + "-slaves"] wait_for_cluster_state( conn=conn, @@ -1061,15 +1276,16 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated' ) + print("Deleting security groups (this will take some time)...") attempt = 1 while attempt <= 3: - print "Attempt %d" % attempt + print("Attempt %d" % attempt) groups = [g for g in conn.get_all_security_groups() if g.name in group_names] success = True # Delete individual rules in all groups before deleting groups to # remove dependencies between them for group in groups: - print "Deleting rules in security group " + group.name + print("Deleting rules in security group " + group.name) for rule in group.rules: for grant in rule.grants: success &= group.revoke(ip_protocol=rule.ip_protocol, @@ -1082,11 +1298,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print("Deleted security group %s" % group.name) except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print("Failed to delete security group %s" % group.name) # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails @@ -1096,18 +1313,21 @@ def real_main(): attempt += 1 if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." + print("Failed to delete all security groups after 3 tries.") + print("Try re-running in a few minutes.") elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print("Logging into master " + master + "...") + proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1117,15 +1337,18 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." + print("Rebooting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id + print("Rebooting " + inst.id) inst.reboot() elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") + else: + print(get_dns_name(master_nodes[0], opts.private_ips)) elif action == "stop": response = raw_input( @@ -1138,11 +1361,11 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." + print("Stopping master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.stop() - print "Stopping slaves..." + print("Stopping slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: if inst.spot_instance_request_id: @@ -1152,11 +1375,11 @@ def real_main(): elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." + print("Starting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - print "Starting master..." + print("Starting master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -1166,18 +1389,29 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='ssh-ready' ) + + # Determine types of running instances + existing_master_type = master_nodes[0].instance_type + existing_slave_type = slave_nodes[0].instance_type + # Setting opts.master_instance_type to the empty string indicates we + # have the same instance type for the master and the slaves + if existing_master_type == existing_slave_type: + existing_master_type = "" + opts.master_instance_type = existing_master_type + opts.instance_type = existing_slave_type + setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: - print >> stderr, "Invalid action: %s" % action + print("Invalid action: %s" % action, file=stderr) sys.exit(1) def main(): try: real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e + except UsageError as e: + print("\nError:\n", e, file=stderr) sys.exit(1) diff --git a/examples/pom.xml b/examples/pom.xml index 4b92147725f6..df1717403b67 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -35,12 +35,6 @@ http://spark.apache.org/ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -96,6 +90,12 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + org.apache.hbase @@ -240,6 +240,7 @@ org.apache.commons commons-math3 + provided com.twitter @@ -268,6 +269,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -276,10 +293,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -287,6 +316,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -310,69 +350,34 @@ org.apache.maven.plugins maven-shade-plugin - - - package - - shade - - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar - - - *:* - - - - - com.google.guava:guava - - - ** - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - com.google - org.spark-project.guava - - com.google.common.** - - - com.google.common.base.Optional** - - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - - - - - reference.conf - - - log4j.properties - - - - - + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + reference.conf + + + log4j.properties + + + @@ -385,11 +390,6 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java new file mode 100644 index 000000000000..bab9f2478e77 --- /dev/null +++ b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import kafka.serializer.StringDecoder; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka.KafkaUtils; +import org.apache.spark.streaming.Durations; + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + */ + +public final class JavaDirectKafkaWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: DirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a list of one or more kafka topics to consume from\n\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + String brokers = args[0]; + String topics = args[1]; + + // Create context with 2 second batch interval + SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); + + HashSet topicsSet = new HashSet(Arrays.asList(topics.split(","))); + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", brokers); + + // Create direct kafka stream with brokers and topics + JavaPairInputDStream messages = KafkaUtils.createDirectStream( + jssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicsSet + ); + + // Get the lines, split them into words, count the words and print + JavaDStream lines = messages.map(new Function, String>() { + @Override + public String call(Tuple2 tuple2) { + return tuple2._2(); + } + }); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + wordCounts.print(); + + // Start the computation + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala new file mode 100644 index 000000000000..11a8cf09533c --- /dev/null +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import kafka.serializer.StringDecoder + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: DirectKafkaWordCount + * is a list of one or more Kafka brokers + * is a list of one or more kafka topics to consume from + * + * Example: + * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ + * topic1,topic2 + */ +object DirectKafkaWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(s""" + |Usage: DirectKafkaWordCount + | is a list of one or more Kafka brokers + | is a list of one or more kafka topics to consume from + | + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(brokers, topics) = args + + // Create context with 2 second batch interval + val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create direct kafka stream with brokers and topics + val topicsSet = topics.split(",").toSet + val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topicsSet) + + // Get the lines, split them into words, count the words and print + val lines = messages.map(_._2) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _) + wordCounts.print() + + // Start the computation + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8..9bbc14ea4087 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,9 +33,9 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating model selection using CrossValidator. @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,14 +112,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = cvModel.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java new file mode 100644 index 000000000000..eaf00d09f550 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.Classifier; +import org.apache.spark.ml.classification.ClassificationModel; +import org.apache.spark.ml.param.IntParam; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.param.Params; +import org.apache.spark.ml.param.Params$; +import org.apache.spark.mllib.linalg.BLAS; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}. + * + * Run with + *
    + * bin/run-example ml.JavaDeveloperApiExample
    + * 
    + */ +public class JavaDeveloperApiExample { + + public static void main(String[] args) throws Exception { + SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // Prepare training data. + List localTraining = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + MyJavaLogisticRegressionModel model = lr.fit(training); + + // Prepare test data. + List localTest = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame results = model.transform(test); + double sumPredictions = 0; + for (Row r : results.select("features", "label", "prediction").collect()) { + sumPredictions += r.getDouble(2); + } + if (sumPredictions != 0.0) { + throw new Exception("MyJavaLogisticRegression predicted something other than 0," + + " even though all weights are 0!"); + } + + jsc.stop(); + } +} + +/** + * Example of defining a type of {@link Classifier}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +class MyJavaLogisticRegression + extends Classifier + implements Params { + + /** + * Param for max number of iterations + *

    + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + */ + IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); + + int getMaxIter() { return (Integer) getOrDefault(maxIter); } + + public MyJavaLogisticRegression() { + setMaxIter(100); + } + + // The parameter setter is in this class since it should return type MyJavaLogisticRegression. + MyJavaLogisticRegression setMaxIter(int value) { + return (MyJavaLogisticRegression) set(maxIter, value); + } + + // This method is used by fit(). + // In Java, we have to make it public since Java does not understand Scala's protected modifier. + public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) { + // Extract columns from data using helper method. + JavaRDD oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD(); + + // Do learning to estimate the weight vector. + int numFeatures = oldDataset.take(1).get(0).features().size(); + Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + + // Create a model, and return it. + return new MyJavaLogisticRegressionModel(this, paramMap, weights); + } +} + +/** + * Example of defining a type of {@link ClassificationModel}. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +class MyJavaLogisticRegressionModel + extends ClassificationModel implements Params { + + private MyJavaLogisticRegression parent_; + public MyJavaLogisticRegression parent() { return parent_; } + + private ParamMap fittingParamMap_; + public ParamMap fittingParamMap() { return fittingParamMap_; } + + private Vector weights_; + public Vector weights() { return weights_; } + + public MyJavaLogisticRegressionModel( + MyJavaLogisticRegression parent_, + ParamMap fittingParamMap_, + Vector weights_) { + this.parent_ = parent_; + this.fittingParamMap_ = fittingParamMap_; + this.weights_ = weights_; + } + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public Vector predictRaw(Vector features) { + double margin = BLAS.dot(features, weights_); + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + return Vectors.dense(-margin, margin); + } + + /** + * Number of classes the label can take. 2 indicates binary classification. + */ + public int numClasses() { return 2; } + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + *

    + * This is used for the defaul implementation of [[transform()]]. + * + * In Java, we have to make this method public since Java does not understand Scala's protected + * modifier. + */ + public MyJavaLogisticRegressionModel copy() { + MyJavaLogisticRegressionModel m = + new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); + Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); + return m; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e83..4e02acce696e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,9 +28,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -48,13 +48,13 @@ public static void main(String[] args) { // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -81,7 +81,7 @@ public static void main(String[] args) { // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); - paramMap2.put(lr.scoreCol().w("probability")); // Change output column name + paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -94,18 +94,18 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = - jsql.sql("SELECT features, label, probability, prediction FROM results"); - for (Row r: results.collect()) { + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7..ef1ec103a879 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,9 +29,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; /** * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,14 +79,15 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); - for (Row r: predictions.collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + DataFrame predictions = model.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java new file mode 100644 index 000000000000..36baf5868736 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; + +/** + * Java example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt + */ +public class JavaFPGrowthExample { + + public static void main(String[] args) { + String inputFile; + double minSupport = 0.3; + int numPartition = -1; + if (args.length < 1) { + System.err.println( + "Usage: JavaFPGrowth [minSupport] [numPartition]"); + System.exit(1); + } + inputFile = args[0]; + if (args.length >= 2) { + minSupport = Double.parseDouble(args[1]); + } + if (args.length >= 3) { + numPartition = Integer.parseInt(args[2]); + } + + SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD> transactions = sc.textFile(inputFile).map( + new Function>() { + @Override + public ArrayList call(String s) { + return Lists.newArrayList(s.split(" ")); + } + } + ); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartition) + .run(transactions); + + for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java new file mode 100644 index 000000000000..36207ae38d9a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SparkConf; + +public class JavaLDAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("LDA Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) + values[i] = Double.parseDouble(sarray[i]); + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + )); + corpus.cache(); + + // Cluster the documents into three topics using LDA + DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java new file mode 100644 index 000000000000..6c6f9768f015 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple3; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.clustering.PowerIterationClustering; +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; + +/** + * Java example for graph clustering using power iteration clustering (PIC). + */ +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaPowerIterationClusteringExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + @SuppressWarnings("unchecked") + JavaRDD> similarities = sc.parallelize(Lists.newArrayList( + new Tuple3(0L, 1L, 0.9), + new Tuple3(1L, 2L, 0.9), + new Tuple3(2L, 3L, 0.9), + new Tuple3(3L, 4L, 0.1), + new Tuple3(4L, 5L, 0.9))); + + PowerIterationClustering pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10); + PowerIterationClusteringModel model = pic.run(similarities); + + for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c..8159ffbe2d26 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,13 +74,13 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override @@ -93,18 +93,18 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + DataFrame teenagers2 = + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,13 +130,13 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD. + // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... // root @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java similarity index 71% rename from core/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java index 4636c4600a01..e63697a79f23 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecord.java @@ -15,15 +15,17 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.examples.streaming; -/** - * This class exists to restrict the visibility of TaskContext setters. - */ -private [spark] object TaskContextHelper { +/** Java Bean class to be used with the example JavaSqlNetworkWordCount. */ +public class JavaRecord implements java.io.Serializable { + private String word; - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + public String getWord() { + return word; + } - def unset(): Unit = TaskContext.unset() - + public void setWord(String word) { + this.word = word; + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java new file mode 100644 index 000000000000..46562ddbbcb5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.regex.Pattern; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: JavaSqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.JavaSqlNetworkWordCount localhost 9999` + */ + +public final class JavaSqlNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaSqlNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + + // Create a JavaReceiverInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaRDD rdd, Time time) { + SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame + JavaRDD rowRDD = rdd.map(new Function() { + public JavaRecord call(String word) { + JavaRecord record = new JavaRecord(); + record.setWord(word); + return record; + } + }); + DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); + + // Register as table + wordsDataFrame.registerTempTable("words"); + + // Do word count on table using SQL and print it + DataFrame wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word"); + System.out.println("========= " + time + "========="); + wordCountsDataFrame.show(); + return null; + } + }); + + ssc.start(); + ssc.awaitTermination(); + } +} + +/** Lazily instantiated singleton instance of SQLContext */ +class JavaSQLContextSingleton { + static private transient SQLContext instance = null; + static public SQLContext getInstance(SparkContext sparkContext) { + if (instance == null) { + instance = new SQLContext(sparkContext); + } + return instance; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java new file mode 100644 index 000000000000..dbf2ef02d7b7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every + * second starting with initial value of word count. + * Usage: JavaStatefulNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. + *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example + * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999` + */ +public class JavaStatefulNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaStatefulNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Update the cumulative count function + final Function2, Optional, Optional> updateFunction = + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + Integer newSum = state.or(0); + for (Integer value : values) { + newSum += value; + } + return Optional.of(newSum); + } + }; + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint("."); + + // Initial RDD input to updateStateByKey + @SuppressWarnings("unchecked") + List> tuples = Arrays.asList(new Tuple2("hello", 1), + new Tuple2("world", 1)); + JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); + + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + JavaPairDStream wordsDstream = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + // This will give a Dstream made of state (which is the cumulative count of the words) + JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, + new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD); + + stateDstream.print(); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 70b6146e39a8..1c3a787bd0e9 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -21,7 +21,8 @@ This example requires numpy (http://www.numpy.org/) """ -from os.path import realpath +from __future__ import print_function + import sys import numpy as np @@ -57,9 +58,9 @@ def update(i, vec, mat, ratings): Usage: als [M] [U] [F] [iterations] [partitions]" """ - print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an + print("""WARN: This is a naive implementation of ALS and is given as an example. Please use the ALS method found in pyspark.mllib.recommendation for more - conventional use.""" + conventional use.""", file=sys.stderr) sc = SparkContext(appName="PythonALS") M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 @@ -68,8 +69,8 @@ def update(i, vec, mat, ratings): ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5 partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2 - print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \ - (M, U, F, ITERATIONS, partitions) + print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % + (M, U, F, ITERATIONS, partitions)) R = matrix(rand(M, F)) * matrix(rand(U, F).T) ms = matrix(rand(M, F)) @@ -95,7 +96,7 @@ def update(i, vec, mat, ratings): usb = sc.broadcast(us) error = rmse(R, ms, us) - print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error + print("Iteration %d:" % i) + print("\nRMSE: %5.4f\n" % error) sc.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4626bbb7e3b0..da368ac628a4 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,9 +15,12 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext +from functools import reduce """ Read data file users.avro in local Spark distro: @@ -49,7 +52,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: avro_inputformat [reader_schema_file] Run with example jar: @@ -57,7 +60,7 @@ /path/to/examples/avro_inputformat.py [reader_schema_file] Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -77,6 +80,6 @@ conf=conf) output = avro_rdd.map(lambda x: x[0]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 05f34b74df45..93ca0cfcc930 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, """ + print(""" Usage: cassandra_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -77,6 +79,6 @@ conf=conf) output = cass_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index d144539e58b8..5d643eac92f9 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -46,7 +48,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: cassandra_outputformat Run with example jar: @@ -60,7 +62,7 @@ ... fname text, ... lname text ... ); - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3b16010f1cb9..e17819d5feb7 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: hbase_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/hbase_inputformat.py
    Assumes you have some data in HBase already, running on , in
    - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -74,6 +76,6 @@ conf=conf) output = hbase_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index abb425b1f886..9e5641789a97 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -40,7 +42,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: hbase_outputformat
    Run with example jar: @@ -48,7 +50,7 @@ /path/to/examples/hbase_outputformat.py Assumes you have created
    with column family in HBase running on already - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 86ef6f32c84e..19391506463f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -22,6 +22,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -47,12 +48,12 @@ def closestPoint(p, centers): if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given + print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on - how to use MLlib's KMeans implementation.""" + how to use MLlib's KMeans implementation.""", file=sys.stderr) sc = SparkContext(appName="PythonKMeans") lines = sc.textFile(sys.argv[1]) @@ -69,13 +70,13 @@ def closestPoint(p, centers): pointStats = closest.reduceByKey( lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() + lambda xy: (xy[0], xy[1][0] / xy[1][1])).collect() tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y - print "Final centers: " + str(kPoints) + print("Final centers: " + str(kPoints)) sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 3aa56b052816..b318b7d87bfd 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -22,10 +22,8 @@ In practice, one may prefer to use the LogisticRegression algorithm in MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py. """ +from __future__ import print_function -from collections import namedtuple -from math import exp -from os.path import realpath import sys import numpy as np @@ -42,19 +40,19 @@ def readPointBatch(iterator): strs = list(iterator) matrix = np.zeros((len(strs), D + 1)) - for i in xrange(len(strs)): - matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ') + for i, s in enumerate(strs): + matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ') return [matrix] if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is + print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py - to see how MLlib's implementation is used.""" + to see how MLlib's implementation is used.""", file=sys.stderr) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() @@ -62,7 +60,7 @@ def readPointBatch(iterator): # Initialize w to a random value w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) + print("Initial w: " + str(w)) # Compute logistic regression gradient for a matrix of data points def gradient(matrix, w): @@ -76,9 +74,9 @@ def add(x, y): return x for i in range(iterations): - print "On iteration %i" % (i + 1) + print("On iteration %i" % (i + 1)) w -= points.map(lambda m: gradient(m, w)).reduce(add) - print "Final w: " + str(w) + print("Final w: " + str(w)) sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py new file mode 100644 index 000000000000..fab21f003b23 --- /dev/null +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row, SQLContext + + +""" +A simple text classification pipeline that recognizes "spark" from +input text. This is to show how to create and configure a Spark ML +pipeline in Python. Run with: + + bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py +""" + + +if __name__ == "__main__": + sc = SparkContext(appName="SimpleTextClassificationPipeline") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4, "spark i j k"), + (5, "l m n"), + (6, "mapreduce spark"), + (7, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + selected = prediction.select("id", "text", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 4218eca822a9..0e13546b88e6 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -18,6 +18,7 @@ """ Correlations using MLlib. """ +from __future__ import print_function import sys @@ -29,7 +30,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: correlations ()" + print("Usage: correlations ()", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: @@ -41,20 +42,20 @@ points = MLUtils.loadLibSVMFile(sc, filepath)\ .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) - print - print 'Summary of data file: ' + filepath - print '%d data points' % points.count() + print() + print('Summary of data file: ' + filepath) + print('%d data points' % points.count()) # Statistics (correlations) - print - print 'Correlation (%s) between label and each feature' % corrType - print 'Feature\tCorrelation' + print() + print('Correlation (%s) between label and each feature' % corrType) + print('Feature\tCorrelation') numFeatures = points.take(1)[0].features.size labelRDD = points.map(lambda lp: lp.label) for i in range(numFeatures): featureRDD = points.map(lambda lp: lp.features[i]) corr = Statistics.corr(labelRDD, featureRDD, corrType) - print '%d\t%g' % (i, corr) - print + print('%d\t%g' % (i, corr)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6e..e23ecc0c5d30 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,9 +16,10 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ +from __future__ import print_function import os import sys @@ -32,31 +33,31 @@ def summarize(dataset): - print "schema: %s" % dataset.schema().json() + print("schema: %s" % dataset.schema().json()) labels = dataset.map(lambda r: r.label) - print "label average: %f" % labels.mean() + print("label average: %f" % labels.mean()) features = dataset.map(lambda r: r.features) summary = Statistics.colStats(features) - print "features average: %r" % summary.mean() + print("features average: %r" % summary.mean()) if __name__ == "__main__": if len(sys.argv) > 2: - print >> sys.stderr, "Usage: dataset_example.py " + print("Usage: dataset_example.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print "Save dataset as a Parquet file to %s." % tempdir + print("Save dataset as a Parquet file to %s." % tempdir) dataset0.saveAsParquetFile(tempdir) - print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + print("Load it back and summarize it again.") + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index fccabd841b13..513ed8fd5145 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import numpy import os @@ -83,18 +84,17 @@ def reindexClassLabels(data): numClasses = len(classCounts) # origToNewLabels: class --> index in 0,...,numClasses-1 if (numClasses < 2): - print >> sys.stderr, \ - "Dataset for classification should have at least 2 classes." + \ - " The given dataset had only %d classes." % numClasses + print("Dataset for classification should have at least 2 classes." + " The given dataset had only %d classes." % numClasses, file=sys.stderr) exit(1) origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - print "numClasses = %d" % numClasses - print "Per-class example fractions, counts:" - print "Class\tFrac\tCount" + print("numClasses = %d" % numClasses) + print("Per-class example fractions, counts:") + print("Class\tFrac\tCount") for c in sortedClasses: frac = classCounts[c] / (numExamples + 0.0) - print "%g\t%g\t%d" % (c, frac, classCounts[c]) + print("%g\t%g\t%d" % (c, frac, classCounts[c])) if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): return (data, origToNewLabels) @@ -105,8 +105,7 @@ def reindexClassLabels(data): def usage(): - print >> sys.stderr, \ - "Usage: decision_tree_runner [libsvm format data filepath]" + print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) exit(1) @@ -133,13 +132,13 @@ def usage(): model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. - print "Trained DecisionTree for classification:" - print " Model numNodes: %d" % model.numNodes() - print " Model depth: %d" % model.depth() - print " Training accuracy: %g" % getAccuracy(model, reindexedData) + print("Trained DecisionTree for classification:") + print(" Model numNodes: %d" % model.numNodes()) + print(" Model depth: %d" % model.depth()) + print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) if model.numNodes() < 20: - print model.toDebugString() + print(model.toDebugString()) else: - print model + print(model) sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py new file mode 100644 index 000000000000..2cb8010cdc07 --- /dev/null +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A Gaussian Mixture Model clustering program using MLlib. +""" +from __future__ import print_function + +import random +import argparse +import numpy as np + +from pyspark import SparkConf, SparkContext +from pyspark.mllib.clustering import GaussianMixture + + +def parseVector(line): + return np.array([float(x) for x in line.split(' ')]) + + +if __name__ == "__main__": + """ + Parameters + ---------- + :param inputFile: Input file path which contains data points + :param k: Number of mixture components + :param convergenceTol: Convergence threshold. Default to 1e-3 + :param maxIterations: Number of EM iterations to perform. Default to 100 + :param seed: Random seed + """ + + parser = argparse.ArgumentParser() + parser.add_argument('inputFile', help='Input File') + parser.add_argument('k', type=int, help='Number of clusters') + parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold') + parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations') + parser.add_argument('--seed', default=random.getrandbits(19), + type=long, help='Random seed') + args = parser.parse_args() + + conf = SparkConf().setAppName("GMM") + sc = SparkContext(conf=conf) + + lines = sc.textFile(args.inputFile) + data = lines.map(parseVector) + model = GaussianMixture.train(data, args.k, args.convergenceTol, + args.maxIterations, args.seed) + for i in range(args.k): + print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray())) + print(("Cluster labels (first 100): ", model.predict(data).take(100))) + sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py new file mode 100644 index 000000000000..781bd61c9d2b --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient boosted Trees classification and regression using MLlib. +""" +from __future__ import print_function + +import sys + +from pyspark.context import SparkContext +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + + +def testClassification(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ + / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification ensemble model:') + print(model.toDebugString()) + + +def testRegression(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ + / float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression ensemble model:') + print(model.toDebugString()) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGradientBoostedTrees") + + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + print('\nRunning example of classification using GradientBoostedTrees\n') + testClassification(trainingData, testData) + + print('\nRunning example of regression using GradientBoostedTrees\n') + testRegression(trainingData, testData) + + sc.stop() diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 2eeb1abeeb12..f901a87fa63a 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -34,12 +35,12 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) k = int(sys.argv[2]) model = KMeans.train(data, k) - print "Final centers: " + str(model.clusterCenters) + print("Final centers: " + str(model.clusterCenters)) sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 8cae27fc4a52..d4f1d34e2d8c 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -20,11 +20,10 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function -from math import exp import sys -import numpy as np from pyspark import SparkContext from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import LogisticRegressionWithSGD @@ -42,12 +41,12 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) model = LogisticRegressionWithSGD.train(points, iterations) - print "Final weights: " + str(model.weights) - print "Final intercept: " + str(model.intercept) + print("Final weights: " + str(model.weights)) + print("Final intercept: " + str(model.intercept)) sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py index d3c24f766432..4cfdad868c66 100755 --- a/examples/src/main/python/mllib/random_forest_example.py +++ b/examples/src/main/python/mllib/random_forest_example.py @@ -22,6 +22,7 @@ For information on multiclass classification, please refer to the decision_tree_runner.py example. """ +from __future__ import print_function import sys @@ -43,7 +44,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification forest model:') @@ -62,8 +63,8 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\ - / float(testData.count()) + testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ + .sum() / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) @@ -71,7 +72,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: random_forest_example" + print("Usage: random_forest_example", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonRandomForestExample") diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 1e8892741e71..729bae30b152 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -18,6 +18,7 @@ """ Randomly generated RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: random_rdd_generation" + print("Usage: random_rdd_generation", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") @@ -37,19 +38,19 @@ # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) - print 'Generated RDD of %d examples sampled from the standard normal distribution'\ - % normalRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples sampled from the standard normal distribution' + % normalRDD.count()) + print(' First 5 samples:') for sample in normalRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() # Example: RandomRDDs.normalVectorRDD normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) - print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()) + print(' First 5 samples:') for sample in normalVectorRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index 92af3af5ebd1..b7033ab7daeb 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -18,6 +18,7 @@ """ Randomly sampled RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: sampled_rdds " + print("Usage: sampled_rdds ", file=sys.stderr) exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] @@ -41,24 +42,24 @@ examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() if numExamples == 0: - print >> sys.stderr, "Error: Data file had no samples to load." + print("Error: Data file had no samples to load.", file=sys.stderr) exit(1) - print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() expectedSampleSize = int(numExamples * fraction) - print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ - % (fraction, expectedSampleSize) + print('Sampling RDD using fraction %g. Expected sample size = %d.' + % (fraction, expectedSampleSize)) sampledRDD = examples.sample(withReplacement=True, fraction=fraction) - print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + print(' RDD.sample(): sample has %d examples' % sampledRDD.count()) sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) - print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + print(' RDD.takeSample(): sample has %d examples' % len(sampledArray)) - print + print() # Example: RDD.sampleByKey() keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) - print ' Keyed data using label (Int) as key ==> Orig' + print(' Keyed data using label (Int) as key ==> Orig') # Count examples per label in original data. keyCountsA = keyedRDD.countByKey() @@ -69,18 +70,18 @@ sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) - print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ - % sizeB + print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' + % sizeB) # Compare samples - print ' \tFractions of examples with key' - print 'Key\tOrig\tSample' + print(' \tFractions of examples with key') + print('Key\tOrig\tSample') for k in sorted(keyCountsA.keys()): fracA = keyCountsA[k] / float(numExamples) if sizeB != 0: fracB = keyCountsB.get(k, 0) / float(sizeB) else: fracB = 0 - print '%d\t%g\t%g' % (k, fracA, fracB) + print('%d\t%g\t%g' % (k, fracA, fracB)) sc.stop() diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 99fef4276a36..40d1b887927e 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -23,6 +23,7 @@ # grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines # This was done so that the example can be run in local mode +from __future__ import print_function import sys @@ -34,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) < 2: - print USAGE + print(USAGE) sys.exit("Argument for file not provided") file_path = sys.argv[1] sc = SparkContext(appName='Word2Vec') @@ -46,5 +47,5 @@ synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index a5f25d78c114..2fdc9773d4eb 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -19,6 +19,7 @@ This is an example implementation of PageRank. For more conventional use, Please refer to PageRank implementation provided by graphx """ +from __future__ import print_function import re import sys @@ -42,11 +43,12 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: pagerank " + print("Usage: pagerank ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is - given as an example! Please refer to PageRank implementation provided by graphx""" + print("""WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""", + file=sys.stderr) # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") @@ -62,19 +64,19 @@ def parseNeighbors(urls): links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() # Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - ranks = links.map(lambda (url, neighbors): (url, 1.0)) + ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0)) # Calculates and updates URL ranks continuously using PageRank algorithm. - for iteration in xrange(int(sys.argv[2])): + for iteration in range(int(sys.argv[2])): # Calculates URL contributions to the rank of other URLs. contribs = links.join(ranks).flatMap( - lambda (url, (urls, rank)): computeContribs(urls, rank)) + lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1])) # Re-calculates URL ranks based on neighbor contributions. ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15) # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): - print "%s has rank: %s." % (link, rank) + print("%s has rank: %s." % (link, rank)) sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index fa4c20ab2028..96ddac761d69 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,14 +36,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, """ + print(""" Usage: parquet_inputformat.py Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \\ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -56,6 +57,6 @@ valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') output = parquet_rdd.map(lambda x: x[1]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index a7c74e969cdb..92e5cf45abc8 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,7 +36,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) + count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) + print("Pi is roughly %f" % (4.0 * count / n)) sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index bb686f17518a..f6b0ecb02c10 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -22,7 +24,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: sort " + print("Usage: sort ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonSort") lines = sc.textFile(sys.argv[1], 1) @@ -33,6 +35,6 @@ # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() for (num, unitcount) in output: - print num + print(num) sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb..2c188759328f 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -15,11 +15,14 @@ # limitations under the License. # +from __future__ import print_function + import os +import sys from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.sql import Row, StructField, StructType, StringType, IntegerType +from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType if __name__ == "__main__": @@ -30,26 +33,30 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.createDataFrame(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.createDataFrame(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. - path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + if len(sys.argv) < 2: + path = "file://" + \ + os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + else: + path = sys.argv[1] + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,13 +68,13 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") for each in teenagers.collect(): - print each[0] + print(each[0]) sc.stop() diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py new file mode 100644 index 000000000000..49b7902185aa --- /dev/null +++ b/examples/src/main/python/status_api_demo.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import time +import threading +import Queue + +from pyspark import SparkConf, SparkContext + + +def delayed(seconds): + def f(x): + time.sleep(seconds) + return x + return f + + +def call_in_background(f, *args): + result = Queue.Queue(1) + t = threading.Thread(target=lambda: result.put(f(*args))) + t.daemon = True + t.start() + return result + + +def main(): + conf = SparkConf().set("spark.ui.showConsoleProgress", "false") + sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf) + + def run(): + rdd = sc.parallelize(range(10), 10).map(delayed(2)) + reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + return reduced.map(delayed(2)).collect() + + result = call_in_background(run) + status = sc.statusTracker() + while result.empty(): + ids = status.getJobIdsForGroup() + for id in ids: + job = status.getJobInfo(id) + print("Job", id, "status: ", job.status) + for sid in job.stageIds: + info = status.getStageInfo(sid) + if info: + print("Stage %d: %d tasks total (%d active, %d complete)" % + (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)) + time.sleep(1) + + print("Job results are:", result.get()) + sc.stop() + +if __name__ == "__main__": + main() diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f7ffb5379681..f815dd26823d 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -25,6 +25,7 @@ Then create a text file in `localdir` and the words in the file will get counted. """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: hdfs_wordcount.py " + print("Usage: hdfs_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py new file mode 100644 index 000000000000..b178e7899b5e --- /dev/null +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + localhost:2181 test` +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: kafka_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingKafkaWordCount") + ssc = StreamingContext(sc, 1) + + zkQuorum, topic = sys.argv[1:] + kvs = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index cfa9c1ff5bfb..2b48bcfd55db 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -25,6 +25,7 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: network_wordcount.py " + print("Usage: network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index fc6827c82bf9..ac91f0a06b17 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -35,6 +35,7 @@ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from the checkpoint data. """ +from __future__ import print_function import os import sys @@ -46,7 +47,7 @@ def createContext(host, port, outputPath): # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint - print "Creating new context" + print("Creating new context") if os.path.exists(outputPath): os.remove(outputPath) sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") @@ -60,8 +61,8 @@ def createContext(host, port, outputPath): def echo(time, rdd): counts = "Counts at time %s %s" % (time, rdd.collect()) - print counts - print "Appending to " + os.path.abspath(outputPath) + print(counts) + print("Appending to " + os.path.abspath(outputPath)) with open(outputPath, 'a') as f: f.write(counts + "\n") @@ -70,8 +71,8 @@ def echo(time, rdd): if __name__ == "__main__": if len(sys.argv) != 5: - print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ - " " + print("Usage: recoverable_network_wordcount.py " + " ", file=sys.stderr) exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py new file mode 100644 index 000000000000..da90c07dbd82 --- /dev/null +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: sql_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` +""" +from __future__ import print_function + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.sql import SQLContext, Row + + +def getSqlContextInstance(sparkContext): + if ('sqlContextSingletonInstance' not in globals()): + globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) + return globals()['sqlContextSingletonInstance'] + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: sql_network_wordcount.py ", file=sys.stderr) + exit(-1) + host, port = sys.argv[1:] + sc = SparkContext(appName="PythonSqlNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. generated by 'nc') + lines = ssc.socketTextStream(host, int(port)) + words = lines.flatMap(lambda line: line.split(" ")) + + # Convert RDDs of the words DStream to DataFrame and run SQL query + def process(time, rdd): + print("========= %s =========" % str(time)) + + try: + # Get the singleton instance of SQLContext + sqlContext = getSqlContextInstance(rdd.context) + + # Convert RDD[String] to RDD[Row] to DataFrame + rowRdd = rdd.map(lambda w: Row(word=w)) + wordsDataFrame = sqlContext.createDataFrame(rowRdd) + + # Register as table + wordsDataFrame.registerTempTable("words") + + # Do word count on table using SQL and print it + wordCountsDataFrame = \ + sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame.show() + except: + pass + + words.foreachRDD(process) + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 18a9a5a452ff..16ef646b7c42 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -29,6 +29,7 @@ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ localhost 9999` """ +from __future__ import print_function import sys @@ -37,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: stateful_network_wordcount.py " + print("Usage: stateful_network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 00a281bfb650..7bf5fb6ddfe2 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from random import Random @@ -49,20 +51,20 @@ def generateGraph(): # the graph to obtain the path (x, z). # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) + edges = tc.map(lambda x_y: (x_y[1], x_y[0])) - oldCount = 0L + oldCount = 0 nextCount = tc.count() while True: oldCount = nextCount # Perform the join, obtaining an RDD of (y, (z, x)) pairs, # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0])) tc = tc.union(new_edges).distinct().cache() nextCount = tc.count() if nextCount == oldCount: break - print "TC has %i edges" % tc.count() + print("TC has %i edges" % tc.count()) sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index ae6cd13b83d9..7c0143607b61 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from operator import add @@ -23,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: wordcount " + print("Usage: wordcount ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonWordCount") lines = sc.textFile(sys.argv[1], 1) @@ -32,6 +34,6 @@ .reduceByKey(add) output = counts.collect() for (word, count) in output: - print "%s: %i" % (word, count) + print("%s: %i" % (word, count)) sc.stop() diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 000000000000..6e6b5cb93789 --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 000000000000..c864a4232d01 --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 000000000000..2a86aa98160d --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 000000000000..aa7a833e147a --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 000000000000..b734cb0ecf55 --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 1b53f3edbe92..4c129dbe2d12 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -29,7 +29,7 @@ object BroadcastTest { val blockSize = if (args.length > 3) args(3) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory") + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index 65251e93190f..e757283823fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.examples import scala.collection.JavaConversions._ +import org.apache.spark.util.Utils + /** Prints out environmental information, sleeps, and then exits. Made to * test driver submission in the standalone scheduler. */ object DriverSubmissionTest { @@ -30,7 +32,7 @@ object DriverSubmissionTest { val numSecondsToSleep = args(0).toInt val env = System.getenv() - val properties = System.getProperties() + val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 822673347bdc..f4684b42b5d4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples import org.apache.hadoop.hbase.client.HBaseAdmin -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, TableName} import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ @@ -36,7 +36,7 @@ object HBaseTest { // Initialize hBase table if necessary val admin = new HBaseAdmin(conf) if (!admin.isTableAvailable(args(0))) { - val tableDesc = new HTableDescriptor(args(0)) + val tableDesc = new HTableDescriptor(TableName.valueOf(args(0))) admin.createTable(tableDesc) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3..f73eac1e2b90 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -40,8 +40,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea..a55e0dc8d36c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -37,8 +37,8 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d8..32e02eab8b03 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -54,8 +54,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. */ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922..8c01a6084462 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -42,8 +42,8 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3d..772cd897f514 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -31,7 +31,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e322d4ce5a74..ab6e63deb3c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -90,7 +90,7 @@ class PRMessage() extends Message[String] with Serializable { } class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = { val hash = key match { diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index e809a65b7997..f6f8d9f90c27 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -17,11 +17,6 @@ package org.apache.spark.examples.graphx -import org.apache.spark.SparkContext._ -import org.apache.spark._ -import org.apache.spark.graphx._ - - /** * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from * http://snap.stanford.edu/data/soc-LiveJournal1.html. @@ -31,13 +26,13 @@ object LiveJournalPageRank { if (args.length < 1) { System.err.println( "Usage: LiveJournalPageRank \n" + + " --numEPart=\n" + + " The number of partitions for the graph's edge RDD.\n" + " [--tol=]\n" + " The tolerance allowed at convergence (smaller => more accurate). Default is " + "0.001.\n" + " [--output=]\n" + " If specified, the file to write the ranks to.\n" + - " [--numEPart=]\n" + - " The number of partitions for the graph's edge RDD. Default is 4.\n" + " [--partStrategy=RandomVertexCut | EdgePartition1D | EdgePartition2D | " + "CanonicalRandomVertexCut]\n" + " The way edges are assigned to edge partitions. Default is RandomVertexCut.") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46..6c0af20461d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,12 +18,12 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} /** @@ -44,10 +44,10 @@ object CrossValidatorExample { val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -90,21 +90,21 @@ object CrossValidatorExample { crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. - val cvModel = crossval.fit(training) + val cvModel = crossval.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + cvModel.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala new file mode 100644 index 000000000000..2cd515c89d3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.ml.tree.DecisionTreeModel +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{SQLContext, DataFrame} + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.DecisionTreeExample [options] + * }}} + * Note that Decision Trees can take a large amount of memory. If the run-example command above + * fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DecisionTreeExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DecisionTreeExample") { + head("DecisionTreeExample: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + /** Load a dataset from the given path, using the given format */ + private[ml] def loadData( + sc: SparkContext, + path: String, + format: String, + expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + format match { + case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "libsvm" => expectedNumFeatures match { + case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) + case None => MLUtils.loadLibSVMFile(sc, path) + } + case _ => throw new IllegalArgumentException(s"Bad data format: $format") + } + } + + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset) + */ + private[ml] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: String, + fracTest: Double): (DataFrame, DataFrame) = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Load training data + val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + + // Load or create test set + val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + // Load testInput. + val numFeatures = origExamples.take(1)(0).features.size + val origTestExamples: RDD[LabeledPoint] = + loadData(sc, testInput, dataFormat, Some(numFeatures)) + Array(origExamples, origTestExamples) + } else { + // Split input into training, test. + origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) + } + + // For classification, convert labels to Strings since we will index them later with + // StringIndexer. + def labelsToStrings(data: DataFrame): DataFrame = { + algo.toLowerCase match { + case "classification" => + data.withColumn("labelString", data("label").cast(StringType)) + case "regression" => + data + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + } + val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + + (dataframes(0), dataframes(1)) + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"DecisionTreeExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = + loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) + + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn DecisionTree + val dt = algo match { + case "classification" => + new DecisionTreeClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case "regression" => + new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Decision Tree from the fitted PipelineModel + val treeModel: DecisionTreeModel = algo match { + case "classification" => + pipelineModel.getModel[DecisionTreeClassificationModel]( + dt.asInstanceOf[DecisionTreeClassifier]) + case "regression" => + pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } + + // Predict on training + val trainingFullPredictions = pipelineModel.transform(training).cache() + val trainingPredictions = trainingFullPredictions.select("prediction") + .map(_.getDouble(0)) + val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) + // Predict on test data + val testFullPredictions = pipelineModel.transform(test).cache() + val testPredictions = testFullPredictions.select("prediction") + .map(_.getDouble(0)) + val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) + + // For classification, print number of classes for reference. + if (algo == "classification") { + val numClasses = + MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "DecisionTreeExample had unknown failure when indexing labels for classification.") + } + println(s"numClasses = $numClasses.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + val trainingAccuracy = + new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision + println(s"Train accuracy = $trainingAccuracy") + val testAccuracy = + new MulticlassMetrics(testPredictions.zip(testLabels)).precision + println(s"Test accuracy = $testAccuracy") + case "regression" => + val trainingRMSE = + new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError + println(s"Training root mean squared error (RMSE) = $trainingRMSE") + val testRMSE = + new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError + println(s"Test root mean squared error (RMSE) = $testRMSE") + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala new file mode 100644 index 000000000000..2245fa429fda --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel} +import org.apache.spark.ml.param.{Params, IntParam, ParamMap} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +/** + * A simple example demonstrating how to write your own learning algorithm using Estimator, + * Transformer, and other abstractions. + * This mimics [[org.apache.spark.ml.classification.LogisticRegression]]. + * Run with + * {{{ + * bin/run-example ml.DeveloperApiExample + * }}} + */ +object DeveloperApiExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("DeveloperApiExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training data. + val training = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new MyLogisticRegression() + // Print out the parameters, documentation, and any default values. + println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model = lr.fit(training.toDF()) + + // Prepare test data. + val test = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) + + // Make predictions on test data. + val sumPredictions: Double = model.transform(test.toDF()) + .select("features", "label", "prediction") + .collect() + .map { case Row(features: Vector, label: Double, prediction: Double) => + prediction + }.sum + assert(sumPredictions == 0.0, + "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + + sc.stop() + } +} + +/** + * Example of defining a parameter trait for a user-defined type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private trait MyLogisticRegressionParams extends ClassifierParams { + + /** + * Param for max number of iterations + * + * NOTE: The usual way to add a parameter to a model or algorithm is to include: + * - val myParamName: ParamType + * - def getMyParamName + * - def setMyParamName + * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression + * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator + * class since the maxIter parameter is only used during training (not in the Model). + */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * Example of defining a type of [[Classifier]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegression + extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + setMaxIter(100) // Initialize + + // The parameter setter is in this class since it should return type MyLogisticRegression. + def setMaxIter(value: Int): this.type = set(maxIter, value) + + // This method is used by fit() + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): MyLogisticRegressionModel = { + // Extract columns from data using helper method. + val oldDataset = extractLabeledPoints(dataset, paramMap) + + // Do learning to estimate the weight vector. + val numFeatures = oldDataset.take(1)(0).features.size + val weights = Vectors.zeros(numFeatures) // Learning would happen here. + + // Create a model, and return it. + new MyLogisticRegressionModel(this, paramMap, weights) + } +} + +/** + * Example of defining a type of [[ClassificationModel]]. + * + * NOTE: This is private since it is an example. In practice, you may not want it to be private. + */ +private class MyLogisticRegressionModel( + override val parent: MyLogisticRegression, + override val fittingParamMap: ParamMap, + val weights: Vector) + extends ClassificationModel[Vector, MyLogisticRegressionModel] + with MyLogisticRegressionParams { + + // This uses the default implementation of transform(), which reads column "features" and outputs + // columns "prediction" and "rawPrediction." + + // This uses the default implementation of predict(), which chooses the label corresponding to + // the maximum value returned by [[predictRaw()]]. + + /** + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + override protected def predictRaw(features: Vector): Vector = { + val margin = BLAS.dot(features, weights) + // There are 2 classes (binary classification), so we return a length-2 vector, + // where index i corresponds to class i (i = 0, 1). + Vectors.dense(-margin, margin) + } + + /** Number of classes the label can take. 2 indicates binary classification. */ + override val numClasses: Int = 2 + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + * + * This is used for the default implementation of [[transform()]]. + */ + override protected def copy(): MyLogisticRegressionModel = { + val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) + Params.inheritValues(extractParamMap(), this, m) + m + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cf62772b9265..25f21113bf62 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -93,8 +93,8 @@ object MovieLensALS { | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ | examples/target/scala-*/spark-examples-*.jar \ | --rank 10 --maxIter 15 --regParam 0.1 \ - | --movies path/to/movielens/movies.dat \ - | --ratings path/to/movielens/ratings.dat + | --movies data/mllib/als/sample_movielens_movies.txt \ + | --ratings data/mllib/als/sample_movielens_ratings.txt """.stripMargin) } @@ -109,7 +109,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() @@ -137,13 +137,13 @@ object MovieLensALS { .setRegParam(params.regParam) .setNumBlocks(params.numBlocks) - val model = als.fit(training) + val model = als.fit(training.toDF()) - val predictions = model.transform(test).cache() + val predictions = model.transform(test.toDF()).cache() // Evaluate the model. // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select('rating, 'prediction) + val mse = predictions.select("rating", "prediction").rdd .flatMap { case Row(rating: Float, prediction: Float) => val err = rating.toDouble - prediction val err2 = err * err @@ -157,17 +157,23 @@ object MovieLensALS { println(s"Test RMSE = $rmse.") // Inspect false positives. - predictions.registerTempTable("prediction") - sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") - sqlContext.sql( - """ - |SELECT userId, prediction.movieId, title, rating, prediction - | FROM prediction JOIN movie ON prediction.movieId = movie.movieId - | WHERE rating <= 1 AND prediction >= 4 - | LIMIT 100 - """.stripMargin) - .collect() - .foreach(println) + // Note: We reference columns in 2 ways: + // (1) predictions("movieId") lets us specify the movieId column in the predictions + // DataFrame, rather than the movieId column in the movies DataFrame. + // (2) $"userId" specifies the userId column in the predictions DataFrame. + // We could also write predictions("userId") but do not have to since + // the movies DataFrame does not have a column "userId." + val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF() + val falsePositives = predictions.join(movies) + .where((predictions("movieId") === movies("movieId")) + && ($"rating" <= 1) && ($"prediction" >= 4)) + .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction") + val numFalsePositives = falsePositives.count() + println(s"Found $numFalsePositives false positives") + if (numFalsePositives > 0) { + println(s"Example false positives:") + falsePositives.limit(100).collect().foreach(println) + } sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929c..bf805149d0af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -38,12 +37,12 @@ object SimpleParamsExample { val conf = new SparkConf().setAppName("SimpleParamsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training data. - // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. - val training = sparkContext.parallelize(Seq( + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes + // into DataFrames, where it uses the case class metadata to infer the schema. + val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -59,7 +58,7 @@ object SimpleParamsExample { .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training) + val model1 = lr.fit(training.toDF()) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -73,29 +72,29 @@ object SimpleParamsExample { paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. - val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. - val model2 = lr.fit(training, paramMapCombined) + val model2 = lr.fit(training.toDF(), paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) - // Prepare test documents. - val test = sparkContext.parallelize(Seq( + // Prepare test data. + val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) - // Make predictions on test documents using the Transformer.transform() method. + // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test.toDF()) + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229de..6772efd2c581 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,10 +20,10 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} @BeanInfo @@ -45,10 +45,10 @@ object SimpleTextClassificationPipeline { val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ + import sqlContext.implicits._ // Prepare training documents, which are labeled. - val training = sparkContext.parallelize(Seq( + val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -69,21 +69,21 @@ object SimpleTextClassificationPipeline { .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. - val model = pipeline.fit(training) + val model = pipeline.fit(training.toDF()) // Prepare test documents, which are unlabeled. - val test = sparkContext.parallelize(Seq( + val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "mapreduce spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. - model.transform(test) - .select('id, 'text, 'score, 'prediction) + model.transform(test.toDF()) + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec732..e943d6c889fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -71,7 +71,7 @@ object DatasetExample { val conf = new SparkConf().setAppName(s"DatasetExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext._ // for implicit conversions + import sqlContext.implicits._ // for implicit conversions // Load input data val origData: RDD[LabeledPoint] = params.dataFormat match { @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDF() + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. - val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns + val labelsDf: DataFrame = df.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = df.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 205d80dd0268..262fd2c9611d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -272,6 +272,8 @@ object DecisionTreeRunner { case Variance => impurity.Variance } + params.checkpointDir.foreach(sc.setCheckpointDir) + val strategy = new Strategy( algo = params.algo, @@ -282,7 +284,6 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain, useNodeIdCache = params.useNodeIdCache, - checkpointDir = params.checkpointDir, checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala similarity index 91% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index de58be38c7bf..df76b45e5081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -18,17 +18,17 @@ package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.clustering.GaussianMixture import org.apache.spark.mllib.linalg.Vectors /** * An example Gaussian Mixture Model EM app. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM + * ./bin/run-example mllib.DenseGaussianMixture * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DenseGmmEM { +object DenseGaussianMixture { def main(args: Array[String]): Unit = { if (args.length < 3) { println("usage: DenseGmmEM [maxIterations]") @@ -46,7 +46,7 @@ object DenseGmmEM { Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - val clusters = new GaussianMixtureEM() + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 11e35598baf5..14cc5cbb679c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -56,7 +56,7 @@ object DenseKMeans { .text(s"number of clusters, required") .action((x, c) => c.copy(k = x)) opt[Int]("numIterations") - .text(s"number of iterations, default; ${defaultParams.numIterations}") + .text(s"number of iterations, default: ${defaultParams.numIterations}") .action((x, c) => c.copy(numIterations = x)) opt[String]("initMode") .text(s"initialization mode (${InitializationMode.values.mkString(",")}), " + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala new file mode 100644 index 000000000000..13f24a1e5961 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.fpm.FPGrowth +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Example for mining frequent itemsets using FP-growth. + * Example usage: ./bin/run-example mllib.FPGrowthExample \ + * --minSupport 0.8 --numPartition 2 ./data/mllib/sample_fpgrowth.txt + */ +object FPGrowthExample { + + case class Params( + input: String = null, + minSupport: Double = 0.3, + numPartition: Int = -1) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("FPGrowthExample") { + head("FPGrowth: an example FP-growth app.") + opt[Double]("minSupport") + .text(s"minimal support level, default: ${defaultParams.minSupport}") + .action((x, c) => c.copy(minSupport = x)) + opt[Int]("numPartition") + .text(s"number of partition, default: ${defaultParams.numPartition}") + .action((x, c) => c.copy(numPartition = x)) + arg[String]("") + .text("input paths to input data set, whose file format is that each line " + + "contains a transaction with each item in String and separated by a space") + .required() + .action((x, c) => c.copy(input = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"FPGrowthExample with $params") + val sc = new SparkContext(conf) + val transactions = sc.textFile(params.input).map(_.split(" ")).cache() + + println(s"Number of transactions: ${transactions.count()}") + + val model = new FPGrowth() + .setMinSupport(params.minSupport) + .setNumPartitions(params.numPartition) + .run(transactions) + + println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") + + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala new file mode 100644 index 000000000000..08a93595a2e1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.text.BreakIterator + +import scala.collection.mutable + +import scopt.OptionParser + +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD + + +/** + * An example Latent Dirichlet Allocation (LDA) app. Run with + * {{{ + * ./bin/run-example mllib.LDAExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LDAExample { + + private case class Params( + input: Seq[String] = Seq.empty, + k: Int = 20, + maxIterations: Int = 10, + docConcentration: Double = -1, + topicConcentration: Double = -1, + vocabSize: Int = 10000, + stopwordFile: String = "", + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LDAExample") { + head("LDAExample: an example LDA app for plain text data.") + opt[Int]("k") + .text(s"number of topics. default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]("maxIterations") + .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]("docConcentration") + .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.docConcentration}") + .action((x, c) => c.copy(docConcentration = x)) + opt[Double]("topicConcentration") + .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.topicConcentration}") + .action((x, c) => c.copy(topicConcentration = x)) + opt[Int]("vocabSize") + .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" + + s" default: ${defaultParams.vocabSize}") + .action((x, c) => c.copy(vocabSize = x)) + opt[String]("stopwordFile") + .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + + s" default: ${defaultParams.stopwordFile}") + .action((x, c) => c.copy(stopwordFile = x)) + opt[String]("checkpointDir") + .text(s"Directory for checkpointing intermediate results." + + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + + s" default: ${defaultParams.checkpointDir}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." + + s" default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + arg[String]("...") + .text("input paths (directories) to plain text corpora." + + " Each text file line should hold 1 document.") + .unbounded() + .required() + .action((x, c) => c.copy(input = c.input :+ x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + parser.showUsageAsError + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"LDAExample with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + // Load documents, and prepare them for LDA. + val preprocessStart = System.nanoTime() + val (corpus, vocabArray, actualNumTokens) = + preprocess(sc, params.input, params.vocabSize, params.stopwordFile) + corpus.cache() + val actualCorpusSize = corpus.count() + val actualVocabSize = vocabArray.size + val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9 + + println() + println(s"Corpus summary:") + println(s"\t Training set size: $actualCorpusSize documents") + println(s"\t Vocabulary size: $actualVocabSize terms") + println(s"\t Training set size: $actualNumTokens tokens") + println(s"\t Preprocessing time: $preprocessElapsed sec") + println() + + // Run LDA. + val lda = new LDA() + lda.setK(params.k) + .setMaxIterations(params.maxIterations) + .setDocConcentration(params.docConcentration) + .setTopicConcentration(params.topicConcentration) + .setCheckpointInterval(params.checkpointInterval) + if (params.checkpointDir.nonEmpty) { + sc.setCheckpointDir(params.checkpointDir.get) + } + val startTime = System.nanoTime() + val ldaModel = lda.run(corpus) + val elapsed = (System.nanoTime() - startTime) / 1e9 + + println(s"Finished training LDA model. Summary:") + println(s"\t Training time: $elapsed sec") + val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + + // Print the topics, showing the top-weighted terms for each topic. + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } + } + println(s"${params.k} topics:") + topics.zipWithIndex.foreach { case (topic, i) => + println(s"TOPIC $i") + topic.foreach { case (term, weight) => + println(s"$term\t$weight") + } + println() + } + sc.stop() + } + + /** + * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors. + * @return (corpus, vocabulary as array, total token count in corpus) + */ + private def preprocess( + sc: SparkContext, + paths: Seq[String], + vocabSize: Int, + stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + + // Get dataset of document texts + // One document per line in each text file. If the input consists of many small files, + // this can result in a large number of small partitions, which can degrade performance. + // In this case, consider using coalesce() to create fewer, larger partitions. + val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) + + // Split text into words + val tokenizer = new SimpleTokenizer(sc, stopwordFile) + val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => + id -> tokenizer.getWords(text) + } + tokenized.cache() + + // Counts words: RDD[(word, wordCount)] + val wordCounts: RDD[(String, Long)] = tokenized + .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } + .reduceByKey(_ + _) + wordCounts.cache() + val fullVocabSize = wordCounts.count() + // Select vocab + // (vocab: Map[word -> id], total tokens after selecting vocab) + val (vocab: Map[String, Int], selectedTokenCount: Long) = { + val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocabSize) + } + (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) + } + + val documents = tokenized.map { case (id, tokens) => + // Filter tokens by vocabulary, and create word count vector representation of document. + val wc = new mutable.HashMap[Int, Int]() + tokens.foreach { term => + if (vocab.contains(term)) { + val termIndex = vocab(term) + wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 + } + } + val indices = wc.keys.toArray.sorted + val values = indices.map(i => wc(i).toDouble) + + val sb = Vectors.sparse(vocab.size, indices, values) + (id, sb) + } + + val vocabArray = new Array[String](vocab.size) + vocab.foreach { case (term, i) => vocabArray(i) = term } + + (documents, vocabArray, selectedTokenCount) + } +} + +/** + * Simple Tokenizer. + * + * TODO: Formalize the interface, and make this a public class in mllib.feature + */ +private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { + + private val stopwords: Set[String] = if (stopwordFile.isEmpty) { + Set.empty[String] + } else { + val stopwordText = sc.textFile(stopwordFile).collect() + stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet + } + + // Matches sequences of Unicode letters + private val allWordRegex = "^(\\p{L}*)$".r + + // Ignore words shorter than this length. + private val minWordLength = 3 + + def getWords(text: String): IndexedSeq[String] = { + + val words = new mutable.ArrayBuffer[String]() + + // Use Java BreakIterator to tokenize text into words. + val wb = BreakIterator.getWordInstance + wb.setText(text) + + // current,end index start,end of each word + var current = wb.first() + var end = wb.next() + while (end != BreakIterator.DONE) { + // Convert to lowercase + val word: String = text.substring(current, end).toLowerCase + // Remove short words and strings that aren't only letters + word match { + case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => + words += w + case _ => + } + + current = end + try { + end = wb.next() + } catch { + case e: Exception => + // Ignore remaining text in line. + // This is a known bug in BreakIterator (for some Java versions), + // which fails when it sees certain characters. + end = BreakIterator.DONE + } + } + words + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 91a0a860d6c7..0bc36ea65e1a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -175,9 +175,12 @@ object MovieLensALS { } /** Compute RMSE (Root Mean Squared Error). */ - def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = { + def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) + : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala new file mode 100644 index 000000000000..6d8b806569df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.log4j.{Level, Logger} +import scopt.OptionParser + +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * Takes an input of K concentric circles and the number of points in the innermost circle. + * The output should be K clusters - each cluster containing precisely the points associated + * with each of the input circles. + * + * Run with + * {{{ + * ./bin/run-example mllib.PowerIterationClusteringExample [options] + * + * Where options include: + * k: Number of circles/clusters + * n: Number of sampled points on innermost circle.. There are proportionally more points + * within the outer/larger circles + * maxIterations: Number of Power Iterations + * outerRadius: radius of the outermost of the concentric circles + * }}} + * + * Here is a sample run and output: + * + * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 + * + * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], + * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * + * + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object PowerIterationClusteringExample { + + case class Params( + input: String = null, + k: Int = 3, + numPoints: Int = 5, + maxIterations: Int = 10, + outerRadius: Double = 3.0 + ) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { + head("PowerIterationClusteringExample: an example PIC app using concentric circles.") + opt[Int]('k', "k") + .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]('n', "n") + .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") + .action((x, c) => c.copy(numPoints = x)) + opt[Int]("maxIterations") + .text(s"number of iterations, default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]('r', "r") + .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") + .action((x, c) => c.copy(outerRadius = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf() + .setMaster("local") + .setAppName(s"PowerIterationClustering with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val model = new PowerIterationClustering() + .setK(params.k) + .setMaxIterations(params.maxIterations) + .run(circlesRdd) + + val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) + val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignmentsStr = assignments + .map { case (k, v) => + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(",") + val sizesStr = assignments.map { + _._2.size + }.sorted.mkString("(", ",", ")") + println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + + sc.stop() + } + + def generateCircle(radius: Double, n: Int): Seq[(Double, Double)] = { + Seq.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (radius * math.cos(theta), radius * math.sin(theta)) + } + } + + def generateCirclesRdd(sc: SparkContext, + nCircles: Int = 3, + nPoints: Int = 30, + outerRadius: Double): RDD[(Long, Long, Double)] = { + + val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} + val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} + val points = (0 until nCircles).flatMap { cx => + generateCircle(radii(cx), groupSizes(cx)) + }.zipWithIndex + val rdd = sc.parallelize(points) + val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => + if (i0 < i1) { + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + } else { + None + } + } + distancesRdd + } + + /** + * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel + */ + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = { + val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) + val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) + coeff * math.exp(expCoeff * ssquares) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index c5bd5b0b178d..1a95048bbfe2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -35,8 +35,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} * * To run on your local machine using the two directories `trainingDir` and `testDir`, * with updates every 5 seconds, and 2 features per data point, call: - * $ bin/run-example \ - * org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2 + * $ bin/run-example mllib.StreamingLinearRegression trainingDir testDir 5 2 * * As you add text files to `trainingDir` the model will continuously update. * Anytime you add text files to `testDir`, you'll see predictions from the current model. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala new file mode 100644 index 000000000000..e1998099c2d7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Train a logistic regression model on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the text files must be labeled data points in the form + * `(y,[x1,x2,x3,...,xn])` + * Where n is the number of features, y is a binary label, and + * n must be the same for train and test. + * + * Usage: StreamingLogisticRegression + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, and 2 features per data point, call: + * $ bin/run-example mllib.StreamingLogisticRegression trainingDir testDir 5 2 + * + * As you add text files to `trainingDir` the model will continuously update. + * Anytime you add text files to `testDir`, you'll see predictions from the current model. + * + */ +object StreamingLogisticRegression { + + def main(args: Array[String]) { + + if (args.length != 4) { + System.err.println( + "Usage: StreamingLogisticRegression ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLogisticRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.zeros(args(3).toInt)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b8..6331d1c0060f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -31,43 +32,43 @@ object RDDRelation { val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. - import sqlContext._ + import sqlContext.implicits._ - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF() // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerTempTable("records") + df.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") - sql("SELECT * FROM records").collect().foreach(println) + sqlContext.sql("SELECT * FROM records").collect().foreach(println) // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) + val count = sqlContext.sql("SELECT COUNT(*) FROM records").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") + val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - rdd.saveAsParquetFile("pair.parquet") + df.saveAsParquetFile("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") - sql("SELECT * FROM parquetFile").collect().foreach(println) + sqlContext.sql("SELECT * FROM parquetFile").collect().foreach(println) sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 5725da184811..b7ba60ec2815 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -43,7 +43,8 @@ object HiveFromSpark { // HiveContext. When not configured by the hive-site.xml, the context automatically // creates metastore_db and warehouse in the current directory. val hiveContext = new HiveContext(sc) - import hiveContext._ + import hiveContext.implicits._ + import hiveContext.sql sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src") @@ -67,7 +68,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerTempTable("records") + rdd.toDF().registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a..92867b44be13 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -85,13 +85,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 6ff0c47793a2..85b9a54b40ba 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -17,8 +17,8 @@ package org.apache.spark.examples.streaming -import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic} -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -31,8 +31,6 @@ import org.apache.spark.SparkConf */ object MQTTPublisher { - var client: MqttClient = _ - def main(args: Array[String]) { if (args.length < 2) { System.err.println("Usage: MQTTPublisher ") @@ -42,25 +40,36 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq + + var client: MqttClient = null try { - var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp") - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val persistence = new MemoryPersistence() + client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) + + client.connect() + + val msgtopic = client.getTopic(topic) + val msgContent = "hello mqtt demo for spark streaming" + val message = new MqttMessage(msgContent.getBytes("utf-8")) + + while (true) { + try { + msgtopic.publish(message) + println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + Thread.sleep(10) + println("Queue is full, wait for to consume data from the message queue") + } + } } catch { case e: MqttException => println("Exception Caught: " + e) + } finally { + if (client != null) { + client.disconnect() + } } - - client.connect() - - val msgtopic: MqttTopic = client.getTopic(topic) - val msg: String = "hello mqtt demo for spark streaming" - - while (true) { - val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes("utf-8")) - msgtopic.publish(message) - println("Published data. topic: " + msgtopic.getName() + " Message: " + message) - } - client.disconnect() } } @@ -96,9 +105,9 @@ object MQTTWordCount { val sparkConf = new SparkConf().setAppName("MQTTWordCount") val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - - val words = lines.flatMap(x => x.toString.split(" ")) + val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817..751b30ea1578 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala new file mode 100644 index 000000000000..5a6b9216a3fb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.util.IntParam +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel + +/** + * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the + * network every second. + * + * Usage: SqlNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example org.apache.spark.examples.streaming.SqlNetworkWordCount localhost 9999` + */ + +object SqlNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + // Create the context with a 2 second batch size + val sparkConf = new SparkConf().setAppName("SqlNetworkWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + + // Create a socket stream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + // Note that no duplication in storage level only for running locally. + // Replication necessary in distributed scenario for fault tolerance. + val lines = ssc.socketTextStream(args(0), args(1).toInt, StorageLevel.MEMORY_AND_DISK_SER) + val words = lines.flatMap(_.split(" ")) + + // Convert RDDs of the words DStream to DataFrame and run SQL query + words.foreachRDD((rdd: RDD[String], time: Time) => { + // Get the singleton instance of SQLContext + val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + import sqlContext.implicits._ + + // Convert RDD[String] to RDD[case class] to DataFrame + val wordsDataFrame = rdd.map(w => Record(w)).toDF() + + // Register as table + wordsDataFrame.registerTempTable("words") + + // Do word count on table using SQL and print it + val wordCountsDataFrame = + sqlContext.sql("select word, count(*) as total from words group by word") + println(s"========= $time =========") + wordCountsDataFrame.show() + }) + + ssc.start() + ssc.awaitTermination() + } +} + + +/** Case class for converting RDD to DataFrame */ +case class Record(word: String) + + +/** Lazily instantiated singleton instance of SQLContext */ +object SQLContextSingleton { + + @transient private var instance: SQLContext = _ + + def getInstance(sparkContext: SparkContext): SQLContext = { + if (instance == null) { + instance = new SQLContext(sparkContext) + } + instance + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd186..e99d1baa72b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +45,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +86,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b6267..54d996b8ac99 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0706f1ebf66e..67907bbfb6d1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e9981722..42df8792f147 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 1f2681394c58..8df7edbdcad3 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd..60e2994431b3 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -187,8 +186,8 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - + override def preferredLocation: Option[String] = Option(host) + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * @@ -198,13 +197,12 @@ class FlumeReceiver( */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 4b732c1592ab..44dec45c227c 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -19,7 +19,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} @@ -121,7 +120,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, hostname: String, @@ -138,7 +136,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses representing the hosts to connect to. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -159,7 +156,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( ssc: StreamingContext, addresses: Seq[InetSocketAddress], @@ -178,7 +174,6 @@ object FlumeUtils { * @param hostname Hostname of the host on which the Spark Sink is running * @param port Port of the host at which the Spark Sink is listening */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -195,7 +190,6 @@ object FlumeUtils { * @param port Port of the host at which the Spark Sink is listening * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, hostname: String, @@ -212,7 +206,6 @@ object FlumeUtils { * @param addresses List of InetSocketAddresses on which the Spark Sink is running. * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], @@ -233,7 +226,6 @@ object FlumeUtils { * result in this stream using more threads * @param storageLevel Storage level to use for storing the received objects */ - @Experimental def createPollingStream( jssc: JavaStreamingContext, addresses: Array[InetSocketAddress], diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index b57a1c71e35b..2edea9b5b69b 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -34,10 +33,9 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { @@ -54,7 +52,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging def beforeFunction() { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } before(beforeFunction()) @@ -214,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") @@ -236,7 +234,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging tx.commit() tx.close() Thread.sleep(500) // Allow some time for the events to reach - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } null } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f333e3891b5f..39e6754c81db 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -19,15 +19,16 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.nio.charset.Charset import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline @@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Find a free port */ private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) @@ -108,7 +109,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) event } @@ -138,20 +139,21 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L status should be (avro.Status.OK) } - val decoder = Charset.forName("UTF-8").newDecoder() eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) output should be (input) } } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml new file mode 100644 index 000000000000..0b79f47647f6 --- /dev/null +++ b/external/kafka-assembly/pom.xml @@ -0,0 +1,102 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.4.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kafka-assembly_2.10 + jar + Spark Project External Kafka Assembly + http://spark.apache.org/ + + + streaming-kafka-assembly + + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index b29b0509656b..f695cff410a1 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -44,7 +44,7 @@ org.apache.kafka kafka_${scala.binary.version} - 0.8.0 + 0.8.1.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala new file mode 100644 index 000000000000..5a74febb4bd4 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Represent the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + */ +@Experimental +final class Broker private( + /** Broker's hostname */ + val host: String, + /** Broker's port */ + val port: Int) extends Serializable { + override def equals(obj: Any): Boolean = obj match { + case that: Broker => + this.host == that.host && + this.port == that.port + case _ => false + } + + override def hashCode: Int = { + 41 * (41 + host.hashCode) + port + } + + override def toString(): String = { + s"Broker($host, $port)" + } +} + +/** + * :: Experimental :: + * Companion object that provides methods to create instances of [[Broker]]. + */ +@Experimental +object Broker { + def create(host: String, port: Int): Broker = + new Broker(host, port) + + def apply(host: String, port: Int): Broker = + new Broker(host, port) + + def unapply(broker: Broker): Option[(String, Int)] = { + if (broker == null) { + None + } else { + Some((broker.host, broker.port)) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala new file mode 100644 index 000000000000..1b1fc8051d05 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.reflect.{classTag, ClassTag} + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream._ + +/** + * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * Starting offsets are specified in advance, + * and this DStream is not responsible for committing offsets, + * so that you can control exactly-once semantics. + * For an easy interface to Kafka-managed offsets, + * see {@link org.apache.spark.streaming.kafka.KafkaCluster} + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler function for translating each message into the desired type + */ +private[streaming] +class DirectKafkaInputDStream[ + K: ClassTag, + V: ClassTag, + U <: Decoder[K]: ClassTag, + T <: Decoder[V]: ClassTag, + R: ClassTag]( + @transient ssc_ : StreamingContext, + val kafkaParams: Map[String, String], + val fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R +) extends InputDStream[R](ssc_) with Logging { + val maxRetries = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRetries", 1) + + protected[streaming] override val checkpointData = + new DirectKafkaInputDStreamCheckpointData + + protected val kc = new KafkaCluster(kafkaParams) + + protected val maxMessagesPerPartition: Option[Long] = { + val ratePerSec = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRatePerPartition", 0) + if (ratePerSec > 0) { + val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 + Some((secsPerBatch * ratePerSec).toLong) + } else { + None + } + } + + protected var currentOffsets = fromOffsets + + @tailrec + protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = { + val o = kc.getLatestLeaderOffsets(currentOffsets.keySet) + // Either.fold would confuse @tailrec, do it manually + if (o.isLeft) { + val err = o.left.get.toString + if (retries <= 0) { + throw new SparkException(err) + } else { + log.error(err) + Thread.sleep(kc.config.refreshLeaderBackoffMs) + latestLeaderOffsets(retries - 1) + } + } else { + o.right.get + } + } + + // limits the maximum number of messages per partition + protected def clamp( + leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = { + maxMessagesPerPartition.map { mmp => + leaderOffsets.map { case (tp, lo) => + tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset)) + } + }.getOrElse(leaderOffsets) + } + + override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = { + val untilOffsets = clamp(latestLeaderOffsets(maxRetries)) + val rdd = KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) + + currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) + Some(rdd) + } + + override def start(): Unit = { + } + + def stop(): Unit = { + } + + private[streaming] + class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } + + override def update(time: Time) { + batchForTime.clear() + generatedRDDs.foreach { kv => + val a = kv._2.asInstanceOf[KafkaRDD[K, V, U, T, R]].offsetRanges.map(_.toTuple).toArray + batchForTime += kv._1 -> a + } + } + + override def cleanup(time: Time) { } + + override def restore() { + // this is assuming that the topics don't change during execution, which is true currently + val topics = fromOffsets.keySet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + + batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + } + } + } + +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala new file mode 100644 index 000000000000..bd767031c184 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.control.NonFatal +import scala.util.Random +import scala.collection.mutable.ArrayBuffer +import java.util.Properties +import kafka.api._ +import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition} +import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import org.apache.spark.SparkException + +/** + * Convenience methods for interacting with a Kafka cluster. + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form + */ +private[spark] +class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { + import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} + + // ConsumerConfig isn't serializable + @transient private var _config: SimpleConsumerConfig = null + + def config: SimpleConsumerConfig = this.synchronized { + if (_config == null) { + _config = SimpleConsumerConfig(kafkaParams) + } + _config + } + + def connect(host: String, port: Int): SimpleConsumer = + new SimpleConsumer(host, port, config.socketTimeoutMs, + config.socketReceiveBufferBytes, config.clientId) + + def connectLeader(topic: String, partition: Int): Either[Err, SimpleConsumer] = + findLeader(topic, partition).right.map(hp => connect(hp._1, hp._2)) + + // Metadata api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-MetadataAPI + // scalastyle:on + + def findLeader(topic: String, partition: Int): Either[Err, (String, Int)] = { + val req = TopicMetadataRequest(TopicMetadataRequest.CurrentVersion, + 0, config.clientId, Seq(topic)) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + resp.topicsMetadata.find(_.topic == topic).flatMap { tm: TopicMetadata => + tm.partitionsMetadata.find(_.partitionId == partition) + }.foreach { pm: PartitionMetadata => + pm.leader.foreach { leader => + return Right((leader.host, leader.port)) + } + } + } + Left(errs) + } + + def findLeaders( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, (String, Int)]] = { + val topics = topicAndPartitions.map(_.topic) + val response = getPartitionMetadata(topics).right + val answer = response.flatMap { tms: Set[TopicMetadata] => + val leaderMap = tms.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.flatMap { pm: PartitionMetadata => + val tp = TopicAndPartition(tm.topic, pm.partitionId) + if (topicAndPartitions(tp)) { + pm.leader.map { l => + tp -> (l.host -> l.port) + } + } else { + None + } + } + }.toMap + + if (leaderMap.keys.size == topicAndPartitions.size) { + Right(leaderMap) + } else { + val missing = topicAndPartitions.diff(leaderMap.keySet) + val err = new Err + err.append(new SparkException(s"Couldn't find leaders for ${missing}")) + Left(err) + } + } + answer + } + + def getPartitions(topics: Set[String]): Either[Err, Set[TopicAndPartition]] = { + getPartitionMetadata(topics).right.map { r => + r.flatMap { tm: TopicMetadata => + tm.partitionsMetadata.map { pm: PartitionMetadata => + TopicAndPartition(tm.topic, pm.partitionId) + } + } + } + } + + def getPartitionMetadata(topics: Set[String]): Either[Err, Set[TopicMetadata]] = { + val req = TopicMetadataRequest( + TopicMetadataRequest.CurrentVersion, 0, config.clientId, topics.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp: TopicMetadataResponse = consumer.send(req) + val respErrs = resp.topicsMetadata.filter(m => m.errorCode != ErrorMapping.NoError) + + if (respErrs.isEmpty) { + return Right(resp.topicsMetadata.toSet) + } else { + respErrs.foreach { m => + val cause = ErrorMapping.exceptionFor(m.errorCode) + val msg = s"Error getting partition metadata for '${m.topic}'. Does the topic exist?" + errs.append(new SparkException(msg, cause)) + } + } + } + Left(errs) + } + + // Leader offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetAPI + // scalastyle:on + + def getLatestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.LatestTime) + + def getEarliestLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = + getLeaderOffsets(topicAndPartitions, OffsetRequest.EarliestTime) + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long + ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { + getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => + r.map { kv => + // mapValues isnt serializable, see SI-7005 + kv._1 -> kv._2.head + } + } + } + + private def flip[K, V](m: Map[K, V]): Map[V, Seq[K]] = + m.groupBy(_._2).map { kv => + kv._1 -> kv._2.keys.toSeq + } + + def getLeaderOffsets( + topicAndPartitions: Set[TopicAndPartition], + before: Long, + maxNumOffsets: Int + ): Either[Err, Map[TopicAndPartition, Seq[LeaderOffset]]] = { + findLeaders(topicAndPartitions).right.flatMap { tpToLeader => + val leaderToTp: Map[(String, Int), Seq[TopicAndPartition]] = flip(tpToLeader) + val leaders = leaderToTp.keys + var result = Map[TopicAndPartition, Seq[LeaderOffset]]() + val errs = new Err + withBrokers(leaders, errs) { consumer => + val partitionsToGetOffsets: Seq[TopicAndPartition] = + leaderToTp((consumer.host, consumer.port)) + val reqMap = partitionsToGetOffsets.map { tp: TopicAndPartition => + tp -> PartitionOffsetRequestInfo(before, maxNumOffsets) + }.toMap + val req = OffsetRequest(reqMap) + val resp = consumer.getOffsetsBefore(req) + val respMap = resp.partitionErrorAndOffsets + partitionsToGetOffsets.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { por: PartitionOffsetsResponse => + if (por.error == ErrorMapping.NoError) { + if (por.offsets.nonEmpty) { + result += tp -> por.offsets.map { off => + LeaderOffset(consumer.host, consumer.port, off) + } + } else { + errs.append(new SparkException( + s"Empty offsets for ${tp}, is ${before} before log beginning?")) + } + } else { + errs.append(ErrorMapping.exceptionFor(por.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find leader offsets for ${missing}")) + Left(errs) + } + } + + // Consumer offset api + // scalastyle:off + // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI + // scalastyle:on + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = { + getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r => + r.map { kv => + kv._1 -> kv._2.offset + } + } + } + + /** Requires Kafka >= 0.8.1.1 */ + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { + var result = Map[TopicAndPartition, OffsetMetadataAndError]() + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq) + val errs = new Err + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.fetchOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { ome: OffsetMetadataAndError => + if (ome.error == ErrorMapping.NoError) { + result += tp -> ome + } else { + errs.append(ErrorMapping.exceptionFor(ome.error)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't find consumer offsets for ${missing}")) + Left(errs) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = { + setConsumerOffsetMetadata(groupId, offsets.map { kv => + kv._1 -> OffsetMetadataAndError(kv._2) + }) + } + + /** Requires Kafka >= 0.8.1.1 */ + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetMetadataAndError] + ): Either[Err, Map[TopicAndPartition, Short]] = { + var result = Map[TopicAndPartition, Short]() + val req = OffsetCommitRequest(groupId, metadata) + val errs = new Err + val topicAndPartitions = metadata.keySet + withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => + val resp = consumer.commitOffsets(req) + val respMap = resp.requestInfo + val needed = topicAndPartitions.diff(result.keySet) + needed.foreach { tp: TopicAndPartition => + respMap.get(tp).foreach { err: Short => + if (err == ErrorMapping.NoError) { + result += tp -> err + } else { + errs.append(ErrorMapping.exceptionFor(err)) + } + } + } + if (result.keys.size == topicAndPartitions.size) { + return Right(result) + } + } + val missing = topicAndPartitions.diff(result.keySet) + errs.append(new SparkException(s"Couldn't set offsets for ${missing}")) + Left(errs) + } + + // Try a call against potentially multiple brokers, accumulating errors + private def withBrokers(brokers: Iterable[(String, Int)], errs: Err) + (fn: SimpleConsumer => Any): Unit = { + brokers.foreach { hp => + var consumer: SimpleConsumer = null + try { + consumer = connect(hp._1, hp._2) + fn(consumer) + } catch { + case NonFatal(e) => + errs.append(e) + } finally { + if (consumer != null) { + consumer.close() + } + } + } + } +} + +private[spark] +object KafkaCluster { + type Err = ArrayBuffer[Throwable] + + private[spark] + case class LeaderOffset(host: String, port: Int, offset: Long) + + /** + * High-level kafka consumers connect to ZK. ConsumerConfig assumes this use case. + * Simple consumers connect directly to brokers, but need many of the same configs. + * This subclass won't warn about missing ZK params, or presence of broker params. + */ + private[spark] + class SimpleConsumerConfig private(brokers: String, originalProps: Properties) + extends ConsumerConfig(originalProps) { + val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => + val hpa = hp.split(":") + if (hpa.size == 1) { + throw new SparkException(s"Broker not the in correct format of : [$brokers]") + } + (hpa(0), hpa(1).toInt) + } + } + + private[spark] + object SimpleConsumerConfig { + /** + * Make a consumer config without requiring group.id or zookeeper.connect, + * since communicating with brokers also needs common settings such as timeout + */ + def apply(kafkaParams: Map[String, String]): SimpleConsumerConfig = { + // These keys are from other pre-existing kafka configs for specifying brokers, accept either + val brokers = kafkaParams.get("metadata.broker.list") + .orElse(kafkaParams.get("bootstrap.servers")) + .getOrElse(throw new SparkException( + "Must specify metadata.broker.list or bootstrap.servers")) + + val props = new Properties() + kafkaParams.foreach { case (key, value) => + // prevent warnings on parameters ConsumerConfig doesn't know about + if (key != "metadata.broker.list" && key != "bootstrap.servers") { + props.put(key, value) + } + } + + Seq("zookeeper.connect", "group.id").foreach { s => + if (!props.contains(s)) { + props.setProperty(s, "") + } + } + + new SimpleConsumerConfig(brokers, props) + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 4d26b640e8d7..cca0fac0234e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * Input stream that pulls messages from a Kafka Broker. @@ -111,7 +111,8 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") + val executorPool = + ThreadUtils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala new file mode 100644 index 000000000000..a1b4a12e5d6a --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + +import kafka.api.{FetchRequestBuilder, FetchResponse} +import kafka.common.{ErrorMapping, TopicAndPartition} +import kafka.consumer.SimpleConsumer +import kafka.message.{MessageAndMetadata, MessageAndOffset} +import kafka.serializer.Decoder +import kafka.utils.VerifiableProperties + +/** + * A batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param messageHandler function for translating each message into the desired type + */ +private[kafka] +class KafkaRDD[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag] private[spark] ( + sc: SparkContext, + kafkaParams: Map[String, String], + val offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, (String, Int)], + messageHandler: MessageAndMetadata[K, V] => R + ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges { + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => + val (host, port) = leaders(TopicAndPartition(o.topic, o.partition)) + new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port) + }.toArray + } + + override def getPreferredLocations(thePart: Partition): Seq[String] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + // TODO is additional hostname resolution necessary here + Seq(part.host) + } + + private def errBeginAfterEnd(part: KafkaRDDPartition): String = + s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition}. " + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged" + + private def errRanOutBeforeEnd(part: KafkaRDDPartition): String = + s"Ran out of messages before reaching ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates that messages may have been lost" + + private def errOvershotEnd(itemOffset: Long, part: KafkaRDDPartition): String = + s"Got ${itemOffset} > ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition} start ${part.fromOffset}." + + " This should not happen, and indicates a message may have been skipped" + + override def compute(thePart: Partition, context: TaskContext): Iterator[R] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + if (part.fromOffset == part.untilOffset) { + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + s"skipping ${part.topic} ${part.partition}") + Iterator.empty + } else { + new KafkaRDDIterator(part, context) + } + } + + private class KafkaRDDIterator( + part: KafkaRDDPartition, + context: TaskContext) extends NextIterator[R] { + + context.addTaskCompletionListener{ context => closeIfNeeded() } + + log.info(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + + val kc = new KafkaCluster(kafkaParams) + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[K]] + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(kc.config.props) + .asInstanceOf[Decoder[V]] + val consumer = connectLeader + var requestOffset = part.fromOffset + var iter: Iterator[MessageAndOffset] = null + + // The idea is to use the provided preferred host, except on task retry atttempts, + // to minimize number of kafka metadata requests + private def connectLeader: SimpleConsumer = { + if (context.attemptNumber > 0) { + kc.connectLeader(part.topic, part.partition).fold( + errs => throw new SparkException( + s"Couldn't connect to leader for topic ${part.topic} ${part.partition}: " + + errs.mkString("\n")), + consumer => consumer + ) + } else { + kc.connect(part.host, part.port) + } + } + + private def handleFetchErr(resp: FetchResponse) { + if (resp.hasError) { + val err = resp.errorCode(part.topic, part.partition) + if (err == ErrorMapping.LeaderNotAvailableCode || + err == ErrorMapping.NotLeaderForPartitionCode) { + log.error(s"Lost leader for topic ${part.topic} partition ${part.partition}, " + + s" sleeping for ${kc.config.refreshLeaderBackoffMs}ms") + Thread.sleep(kc.config.refreshLeaderBackoffMs) + } + // Let normal rdd retry sort out reconnect attempts + throw ErrorMapping.exceptionFor(err) + } + } + + private def fetchBatch: Iterator[MessageAndOffset] = { + val req = new FetchRequestBuilder() + .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) + .build() + val resp = consumer.fetch(req) + handleFetchErr(resp) + // kafka may return a batch that starts before the requested offset + resp.messageSet(part.topic, part.partition) + .iterator + .dropWhile(_.offset < requestOffset) + } + + override def close(): Unit = consumer.close() + + override def getNext(): R = { + if (iter == null || !iter.hasNext) { + iter = fetchBatch + } + if (!iter.hasNext) { + assert(requestOffset == part.untilOffset, errRanOutBeforeEnd(part)) + finished = true + null.asInstanceOf[R] + } else { + val item = iter.next() + if (item.offset >= part.untilOffset) { + assert(item.offset == part.untilOffset, errOvershotEnd(item.offset, part)) + finished = true + null.asInstanceOf[R] + } else { + requestOffset = item.nextOffset + messageHandler(new MessageAndMetadata( + part.topic, part.partition, item.message, item.offset, keyDecoder, valueDecoder)) + } + } + } + } +} + +private[kafka] +object KafkaRDD { + import KafkaCluster.LeaderOffset + + /** + * @param kafkaParams Kafka + * configuration parameters. + * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the batch + * @param untilOffsets per-topic/partition Kafka offsets defining the (exclusive) + * ending point of the batch + * @param messageHandler function for translating each message into the desired type + */ + def apply[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + untilOffsets: Map[TopicAndPartition, LeaderOffset], + messageHandler: MessageAndMetadata[K, V] => R + ): KafkaRDD[K, V, U, T, R] = { + val leaders = untilOffsets.map { case (tp, lo) => + tp -> (lo.host, lo.port) + }.toMap + + val offsetRanges = fromOffsets.map { case (tp, fo) => + val uo = untilOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo.offset) + }.toArray + + new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala new file mode 100644 index 000000000000..a842a6f17766 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import org.apache.spark.Partition + +/** @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ +private[kafka] +class KafkaRDDPartition( + val index: Int, + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long, + val host: String, + val port: Int +) extends Partition diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala new file mode 100644 index 000000000000..13e947506597 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap} +import java.util.Properties +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} +import kafka.serializer.StringEncoder +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.ZKStringSerializer +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +private class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkClient: ZkClient = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkClient = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkClient).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration) + server = new KafkaServer(brokerConf) + server.startup() + (server, port) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + AdminUtils.createTopic(zkClient, topic, 1, 1) + // wait until metadata is propagated + waitUntilMetadataIsPropagated(topic, 0) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + import scala.collection.JavaConversions._ + sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new Producer[String, String](new ProducerConfig(producerConfiguration)) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerAddress) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + eventually(Time(10000), Time(100)) { + assert( + server.apis.metadataCache.containsTopicAndPartition(topic, partition), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index df725f0c65a6..5a9bd4214cf5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -18,21 +18,30 @@ package org.apache.spark.streaming.kafka import java.lang.{Integer => JInt} +import java.lang.{Long => JLong} import java.util.{Map => JMap} +import java.util.{Set => JSet} import scala.reflect.ClassTag import scala.collection.JavaConversions._ -import kafka.serializer.{Decoder, StringDecoder} +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} object KafkaUtils { /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) * @param groupId The group id for this consumer @@ -56,7 +65,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param ssc StreamingContext object * @param kafkaParams Map of kafka configuration parameters, * see http://kafka.apache.org/08/configuration.html @@ -75,7 +84,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) @@ -93,7 +102,7 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. @@ -113,10 +122,10 @@ object KafkaUtils { } /** - * Create an input stream that pulls messages from a Kafka Broker. + * Create an input stream that pulls messages from Kafka Brokers. * @param jssc JavaStreamingContext object - * @param keyTypeClass Key type of RDD - * @param valueTypeClass value type of RDD + * @param keyTypeClass Key type of DStream + * @param valueTypeClass value type of Dstream * @param keyDecoderClass Type of kafka key decoder * @param valueDecoderClass Type of kafka value decoder * @param kafkaParams Map of kafka configuration parameters, @@ -144,4 +153,409 @@ object KafkaUtils { createStream[K, V, U, T]( jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } + + /** get leaders for the given offset ranges, or throw an exception */ + private def leadersForRanges( + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { + val kc = new KafkaCluster(kafkaParams) + val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet + val leaders = kc.findLeaders(topics).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + leaders + } + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange] + ): RDD[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val leaders = leadersForRanges(kafkaParams, offsetRanges) + new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param sc SparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag]( + sc: SparkContext, + kafkaParams: Map[String, String], + offsetRanges: Array[OffsetRange], + leaders: Map[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[K, V] => R + ): RDD[R] = { + val leaderMap = if (leaders.isEmpty) { + leadersForRanges(kafkaParams, offsetRanges) + } else { + // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker + leaders.map { + case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) + }.toMap + } + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + } + + + /** + * Create a RDD from Kafka using offset ranges for each topic and partition. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange] + ): JavaPairRDD[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + new JavaPairRDD(createRDD[K, V, KD, VD]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + } + + /** + * :: Experimental :: + * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * specify the Kafka leader to connect to (to optimize fetching) and access the message as well + * as the metadata. + * + * @param jsc JavaSparkContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param offsetRanges Each OffsetRange in the batch corresponds to a + * range of offsets for a given Kafka topic/partition + * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, + * in which case leaders will be looked up on the driver. + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jsc: JavaSparkContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + offsetRanges: Array[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaRDD[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val leaderMap = Map(leaders.toSeq: _*) + createRDD[K, V, KD, VD, R]( + jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers) specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag, + R: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + fromOffsets: Map[TopicAndPartition, Long], + messageHandler: MessageAndMetadata[K, V] => R + ): InputDStream[R] = { + new DirectKafkaInputDStream[K, V, KD, VD, R]( + ssc, kafkaParams, fromOffsets, messageHandler) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param ssc StreamingContext object + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[ + K: ClassTag, + V: ClassTag, + KD <: Decoder[K]: ClassTag, + VD <: Decoder[V]: ClassTag] ( + ssc: StreamingContext, + kafkaParams: Map[String, String], + topics: Set[String] + ): InputDStream[(K, V)] = { + val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) + val kc = new KafkaCluster(kafkaParams) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + + (for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + val fromOffsets = leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) + }).fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class of the value decoder + * @param recordClass Class of the records in DStream + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) + * starting point of the stream + * @param messageHandler Function for translating each message and metadata into the desired type + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + recordClass: Class[R], + kafkaParams: JMap[String, String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: JFunction[MessageAndMetadata[K, V], R] + ): JavaInputDStream[R] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + createDirectStream[K, V, KD, VD, R]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + messageHandler.call _ + ) + } + + /** + * :: Experimental :: + * Create an input stream that directly pulls messages from Kafka Brokers + * without using any receiver. This stream can guarantee that each message + * from Kafka is included in transformations exactly once (see points below). + * + * Points to note: + * - No receivers: This stream does not use any receiver. It directly queries Kafka + * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked + * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + * You can access the offsets used in each batch from the generated RDDs (see + * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). + * - Failure Recovery: To recover from driver failures, you have to enable checkpointing + * in the [[StreamingContext]]. The information on consumed offset can be + * recovered from the checkpoint. See the programming guide for details (constraints, etc.). + * - End-to-end semantics: This stream ensures that every records is effectively received and + * transformed exactly once, but gives no guarantees on whether the transformed data are + * outputted exactly once. For end-to-end exactly-once semantics, you have to either ensure + * that the output operation is idempotent, or use transactions to output records atomically. + * See the programming guide for more details. + * + * @param jssc JavaStreamingContext object + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param keyDecoderClass Class of the key decoder + * @param valueDecoderClass Class type of the value decoder + * @param kafkaParams Kafka + * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" + * to be set with Kafka broker(s) (NOT zookeeper servers), specified in + * host1:port1,host2:port2 form. + * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" + * to determine where the stream starts (defaults to "largest") + * @param topics Names of the topics to consume + */ + @Experimental + def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( + jssc: JavaStreamingContext, + keyClass: Class[K], + valueClass: Class[V], + keyDecoderClass: Class[KD], + valueDecoderClass: Class[VD], + kafkaParams: JMap[String, String], + topics: JSet[String] + ): JavaPairInputDStream[K, V] = { + implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) + implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) + implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) + implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) + createDirectStream[K, V, KD, VD]( + jssc.ssc, + Map(kafkaParams.toSeq: _*), + Set(topics.toSeq: _*) + ) + } +} + +/** + * This is a helper class that wraps the KafkaUtils.createStream() into more + * Python-friendly class and function so that it can be easily + * instantiated and called from Python's KafkaUtils (see SPARK-6027). + * + * The zero-arg constructor helps instantiate this class from the Class object + * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() + * takes care of known parameters instead of passing them from Python + */ +private class KafkaUtilsPythonHelper { + def createStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JMap[String, JInt], + storageLevel: StorageLevel): JavaPairReceiverInputDStream[Array[Byte], Array[Byte]] = { + KafkaUtils.createStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics, + storageLevel) + } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala new file mode 100644 index 000000000000..9c3dfeb8f592 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import kafka.common.TopicAndPartition + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * offset ranges in RDDs generated by the direct Kafka DStream (see + * [[KafkaUtils.createDirectStream()]]). + * {{{ + * KafkaUtils.createDirectStream(...).foreachRDD { rdd => + * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + * ... + * } + * }}} + */ +@Experimental +trait HasOffsetRanges { + def offsetRanges: Array[OffsetRange] +} + +/** + * :: Experimental :: + * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class + * can be created with `OffsetRange.create()`. + */ +@Experimental +final class OffsetRange private( + /** Kafka topic name */ + val topic: String, + /** Kafka partition id */ + val partition: Int, + /** inclusive starting offset */ + val fromOffset: Long, + /** exclusive ending offset */ + val untilOffset: Long) extends Serializable { + import OffsetRange.OffsetRangeTuple + + override def equals(obj: Any): Boolean = obj match { + case that: OffsetRange => + this.topic == that.topic && + this.partition == that.partition && + this.fromOffset == that.fromOffset && + this.untilOffset == that.untilOffset + case _ => false + } + + override def hashCode(): Int = { + toTuple.hashCode() + } + + override def toString(): String = { + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]" + } + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[streaming] + def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset) +} + +/** + * :: Experimental :: + * Companion object the provides methods to create instances of [[OffsetRange]]. + */ +@Experimental +object OffsetRange { + def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def create( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def apply( + topicAndPartition: TopicAndPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset) + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[kafka] + type OffsetRangeTuple = (String, Int, Long, Long) + + private[kafka] + def apply(t: OffsetRangeTuple) = + new OffsetRange(t._1, t._2, t._3, t._4) +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index be734b80272d..ea87e960379f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -33,7 +33,7 @@ import org.I0Itec.zkclient.ZkClient import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. @@ -121,7 +121,7 @@ class ReliableKafkaReceiver[ zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) - messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + messageHandlerThreadPool = ThreadUtils.newDaemonFixedThreadPool( topics.values.sum, "KafkaMessageHandler") blockGenerator.start() @@ -201,12 +201,31 @@ class ReliableKafkaReceiver[ topicPartitionOffsetMap.clear() } - /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + /** + * Store the ready-to-be-stored block and commit the related offsets to zookeeper. This method + * will try a fixed number of times to push the block. If the push fails, the receiver is stopped. + */ private def storeBlockAndCommitOffset( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) - Option(blockOffsetMap.get(blockId)).foreach(commitOffset) - blockOffsetMap.remove(blockId) + var count = 0 + var pushed = false + var exception: Exception = null + while (!pushed && count <= 3) { + try { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + pushed = true + } catch { + case ex: Exception => + count += 1 + exception = ex + } + } + if (pushed) { + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } else { + stop("Error while storing block into Spark", exception) + } } /** diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java new file mode 100644 index 000000000000..4c1d6a03eb2b --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Arrays; + +import scala.Tuple2; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaDirectKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + } + + @After + public void tearDown() { + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaStream() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashSet sent = new HashSet(); + sent.addAll(Arrays.asList(topic1data)); + sent.addAll(Arrays.asList(topic2data)); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + kafkaParams.put("auto.offset.reset", "smallest"); + + JavaDStream stream1 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + topicToSet(topic1) + ).map( + new Function, String>() { + @Override + public String call(Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaDStream stream2 = KafkaUtils.createDirectStream( + ssc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + topicOffsetToMap(topic2, (long) 0), + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + JavaDStream unifiedStream = stream1.union(stream2); + + final HashSet result = new HashSet(); + unifiedStream.foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaRDD rdd) throws Exception { + result.addAll(rdd.collect()); + return null; + } + } + ); + ssc.start(); + long startTime = System.currentTimeMillis(); + boolean matches = false; + while (!matches && System.currentTimeMillis() - startTime < 20000) { + matches = sent.size() == result.size(); + Thread.sleep(50); + } + Assert.assertEquals(sent, result); + ssc.stop(); + } + + private HashSet topicToSet(String topic) { + HashSet topicSet = new HashSet(); + topicSet.add(topic); + return topicSet; + } + + private HashMap topicOffsetToMap(String topic, Long offsetToStart) { + HashMap topicMap = new HashMap(); + topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); + return topicMap; + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java new file mode 100644 index 000000000000..a9dc6e50613c --- /dev/null +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka; + +import java.io.Serializable; +import java.util.HashMap; + +import scala.Tuple2; + +import kafka.common.TopicAndPartition; +import kafka.message.MessageAndMetadata; +import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +public class JavaKafkaRDDSuite implements Serializable { + private transient JavaSparkContext sc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + sc = new JavaSparkContext(sparkConf); + } + + @After + public void tearDown() { + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaRDD() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + HashMap kafkaParams = new HashMap(); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + + OffsetRange[] offsetRanges = { + OffsetRange.create(topic1, 0, 0, 1), + OffsetRange.create(topic2, 0, 0, 1) + }; + + HashMap emptyLeaders = new HashMap(); + HashMap leaders = new HashMap(); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); + Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); + leaders.put(new TopicAndPartition(topic1, 0), broker); + leaders.put(new TopicAndPartition(topic2, 0), broker); + + JavaRDD rdd1 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + kafkaParams, + offsetRanges + ).map( + new Function, String>() { + @Override + public String call(Tuple2 kv) throws Exception { + return kv._2(); + } + } + ); + + JavaRDD rdd2 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + emptyLeaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + JavaRDD rdd3 = KafkaUtils.createRDD( + sc, + String.class, + String.class, + StringDecoder.class, + StringDecoder.class, + String.class, + kafkaParams, + offsetRanges, + leaders, + new Function, String>() { + @Override + public String call(MessageAndMetadata msgAndMd) throws Exception { + return msgAndMd.message(); + } + } + ); + + // just making sure the java user apis work; the scala tests handle logic corner cases + long count1 = rdd1.count(); + long count2 = rdd2.count(); + long count3 = rdd3.count(); + Assert.assertTrue(count1 > 0); + Assert.assertEquals(count1, count2); + Assert.assertEquals(count1, count3); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 6e1abf3f385e..540f4ceabab4 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -22,37 +22,32 @@ import java.util.List; import java.util.Random; -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.Duration; -import scala.Predef; import scala.Tuple2; -import scala.collection.JavaConverters; - -import junit.framework.Assert; import kafka.serializer.StringDecoder; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.Test; -import org.junit.After; -import org.junit.Before; - public class JavaKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; private transient Random random = new Random(); - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); ssc = new JavaStreamingContext(sparkConf, new Duration(500)); @@ -60,10 +55,15 @@ public void setUp() { @After public void tearDown() { - ssc.stop(); - ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -77,14 +77,11 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - suiteBase.createTopic(topic); - HashMap tmp = new HashMap(sent); - suiteBase.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, sent); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -127,6 +124,7 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); + long startTime = System.currentTimeMillis(); boolean sizeMatches = false; while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { @@ -137,6 +135,5 @@ public Void call(JavaPairRDD rdd) throws Exception { for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } - ssc.stop(); } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala new file mode 100644 index 000000000000..415730f5559c --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils + +class DirectKafkaStreamSuite + extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ + + override def beforeAll { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + after { + if (ssc != null) { + ssc.stop() + sc = null + } + if (sc != null) { + sc.stop() + } + if (testDir != null) { + Utils.deleteRecursively(testDir) + } + } + + + test("basic stream receiving with multiple topics and smallest starting offset") { + val topics = Set("basic1", "basic2", "basic3") + val data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) + } + val totalSent = data.values.sum * topics.size + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics) + } + + val allReceived = new ArrayBuffer[(String, String)] + + stream.foreachRDD { rdd => + // Get the offset ranges in the RDD + val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsets(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + } + ssc.stop() + } + + test("receiving from largest starting offset") { + val topic = "largest" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() > 3) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + + test("creating stream by offset") { + val topic = "offset" + val topicPartition = TopicAndPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "largest" + ) + val kc = new KafkaCluster(kafkaParams) + def getLatestOffset(): Long = { + kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() >= 10) + } + val offsetBeforeStart = getLatestOffset() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String]( + ssc, kafkaParams, Map(topicPartition -> 11L), + (m: MessageAndMetadata[String, String]) => m.message()) + } + assert( + stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]] + .fromOffsets(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + + val collectedData = new mutable.ArrayBuffer[String]() + stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + // Test to verify the offset ranges can be recovered from the checkpoints + test("offset recovery") { + val topic = "recovery" + kafkaTestUtils.createTopic(topic) + testDir = Utils.createTempDir() + + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + // Send data to Kafka and wait for it to be received + def sendDataAndWaitForReceive(data: Seq[Int]) { + val strings = data.map { _.toString} + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) + } + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val kafkaStream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Set(topic)) + } + val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt } + val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) => + Some(values.sum + state.getOrElse(0)) + } + ssc.checkpoint(testDir.getAbsolutePath) + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + DirectKafkaStreamSuite.collectedData.appendAll(data) + } + + // This is ensure all the data is eventually receiving only once + stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => + rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 } + } + ssc.start() + + // Send some data and wait for them to be received + for (i <- (1 to 10).grouped(4)) { + sendDataAndWaitForReceive(i) + } + + // Verify that offset ranges were generated + val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) + assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + assert( + offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + "starting offset not zero" + ) + ssc.stop() + logInfo("====== RESTARTING ========") + + // Recover context from checkpoints + ssc = new StreamingContext(testDir.getAbsolutePath) + val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]] + + // Verify offset ranges have been recovered + val recoveredOffsetRanges = getOffsetRanges(recoveredStream) + assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") + val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + assert( + recoveredOffsetRanges.forall { or => + earlierOffsetRangesAsSets.contains((or._1, or._2.toSet)) + }, + "Recovered ranges are not the same as the ones generated" + ) + + // Restart context, give more data and verify the total at the end + // If the total is write that means each records has been received only once + ssc.start() + sendDataAndWaitForReceive(11 to 20) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total === (1 to 20).sum) + } + ssc.stop() + } + + /** Get the generated offset ranges from the DirectKafkaStream */ + private def getOffsetRanges[K, V]( + kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { + kafkaStream.generatedRDDs.mapValues { rdd => + rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges + }.toSeq.sortBy { _._1 } + } +} + +object DirectKafkaStreamSuite { + val collectedData = new mutable.ArrayBuffer[String]() + var total = -1L +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala new file mode 100644 index 000000000000..7fb841b79cb6 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.common.TopicAndPartition +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { + private val topic = "kcsuitetopic" + Random.nextInt(10000) + private val topicAndPartition = TopicAndPartition(topic, 0) + private var kc: KafkaCluster = null + + private var kafkaTestUtils: KafkaTestUtils = _ + + override def beforeAll() { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) + } + + override def afterAll() { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + test("metadata apis") { + val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) + val leaderAddress = s"${leader._1}:${leader._2}" + assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") + + val parts = kc.getPartitions(Set(topic)).right.get + assert(parts(topicAndPartition), "didn't get partitions") + + val err = kc.getPartitions(Set(topic + "BAD")) + assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") + } + + test("leader offset apis") { + val earliest = kc.getEarliestLeaderOffsets(Set(topicAndPartition)).right.get + assert(earliest(topicAndPartition).offset === 0, "didn't get earliest") + + val latest = kc.getLatestLeaderOffsets(Set(topicAndPartition)).right.get + assert(latest(topicAndPartition).offset === 1, "didn't get latest") + } + + test("consumer offset apis") { + val group = "kcsuitegroup" + Random.nextInt(10000) + + val offset = Random.nextInt(10000) + + val set = kc.setConsumerOffsets(group, Map(topicAndPartition -> offset)) + assert(set.isRight, "didn't set consumer offsets") + + val get = kc.getConsumerOffsets(group, Set(topicAndPartition)).right.get + assert(get(topicAndPartition) === offset, "didn't get consumer offsets") + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala new file mode 100644 index 000000000000..7d26ce50875b --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import scala.util.Random + +import kafka.serializer.StringDecoder +import kafka.common.TopicAndPartition +import kafka.message.MessageAndMetadata +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark._ + +class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + + override def beforeAll { + sc = new SparkContext(sparkConf) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (sc != null) { + sc.stop + sc = null + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + test("basic usage") { + val topic = "topicbasic" + kafkaTestUtils.createTopic(topic) + val messages = Set("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages.toArray) + + + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, offsetRanges) + + val received = rdd.map(_._2).collect.toSet + assert(received === messages) + } + + test("iterator boundary conditions") { + // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + kafkaTestUtils.createTopic(topic) + + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + + val kc = new KafkaCluster(kafkaParams) + + // this is the "lots of messages" case + kafkaTestUtils.sendMessages(topic, sent) + // rdd defined from leaders after sending messages, should get the number sent + val rdd = getRdd(kc, Set(topic)) + + assert(rdd.isDefined) + assert(rdd.get.count === sent.values.sum, "didn't get all sent messages") + + val ranges = rdd.get.asInstanceOf[HasOffsetRanges] + .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + kc.setConsumerOffsets(kafkaParams("group.id"), ranges) + + // this is the "0 messages" case + val rdd2 = getRdd(kc, Set(topic)) + // shouldn't get anything, since message is sent after rdd was defined + val sentOnlyOne = Map("d" -> 1) + + kafkaTestUtils.sendMessages(topic, sentOnlyOne) + assert(rdd2.isDefined) + assert(rdd2.get.count === 0, "got messages when there shouldn't be any") + + // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above + val rdd3 = getRdd(kc, Set(topic)) + // send lots of messages after rdd was defined, they shouldn't show up + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) + + assert(rdd3.isDefined) + assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") + + } + + // get an rdd from the committed consumer offsets until the latest leader offsets, + private def getRdd(kc: KafkaCluster, topics: Set[String]) = { + val groupId = kc.kafkaParams("group.id") + def consumerOffsets(topicPartitions: Set[TopicAndPartition]) = { + kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse( + kc.getEarliestLeaderOffsets(topicPartitions).right.toOption.map { offs => + offs.map(kv => kv._1 -> kv._2.offset) + } + ) + } + kc.getPartitions(topics).right.toOption.flatMap { topicPartitions => + consumerOffsets(topicPartitions).flatMap { from => + kc.getLatestLeaderOffsets(topicPartitions).right.toOption.map { until => + val offsetRanges = from.map { case (tp: TopicAndPartition, fromOffset: Long) => + OffsetRange(tp.topic, tp.partition, fromOffset, until(tp).offset) + }.toArray + + val leaders = until.map { case (tp: TopicAndPartition, lo: KafkaCluster.LeaderOffset) => + tp -> Broker(lo.host, lo.port) + }.toMap + + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder, String]( + sc, kc.kafkaParams, offsetRanges, leaders, + (mmd: MessageAndMetadata[String, String]) => s"${mmd.offset} ${mmd.message}") + } + } + } + } +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index b19c053ebfc4..24699dfc33ad 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,199 +17,38 @@ package org.apache.spark.streaming.kafka -import java.io.File -import java.net.InetSocketAddress -import java.util.Properties - import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.CreateTopicCommand -import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, Producer, ProducerConfig} -import kafka.serializer.{StringDecoder, StringEncoder} -import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient.ZkClient -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.scalatest.{BeforeAndAfter, FunSuite} +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils - -/** - * This is an abstract base class for Kafka testsuites. This has the functionality to set up - * and tear down local Kafka servers, and to push data using Kafka producers. - */ -abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - - var zkAddress: String = _ - var zkClient: ZkClient = _ - - private val zkHost = "localhost" - private val zkConnectionTimeout = 6000 - private val zkSessionTimeout = 6000 - private var zookeeper: EmbeddedZookeeper = _ - private var zkPort: Int = 0 - private var brokerPort = 9092 - private var brokerConf: KafkaConfig = _ - private var server: KafkaServer = _ - private var producer: Producer[String, String] = _ - - def setupKafka() { - // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") - // Get the actual zookeeper binding port - zkPort = zookeeper.actualPort - zkAddress = s"$zkHost:$zkPort" - logInfo("==================== 0 ====================") - - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, - ZKStringSerializer) - logInfo("==================== 1 ====================") - - // Kafka broker startup - var bindSuccess: Boolean = false - while(!bindSuccess) { - try { - val brokerProps = getBrokerConfig() - brokerConf = new KafkaConfig(brokerProps) - server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") - server.startup() - logInfo("==================== 3 ====================") - bindSuccess = true - } catch { - case e: KafkaException => - if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { - brokerPort += 1 - } - case e: Exception => throw new Exception("Kafka server create failed", e) - } - } - - Thread.sleep(2000) - logInfo("==================== 4 ====================") - } - - def tearDownKafka() { - if (producer != null) { - producer.close() - producer = null - } - - if (server != null) { - server.shutdown() - server = null - } - - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - if (zkClient != null) { - zkClient.close() - zkClient = null - } - - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } - - private def createTestMessage(topic: String, sent: Map[String, Int]) - : Seq[KeyedMessage[String, String]] = { - val messages = for ((s, freq) <- sent; i <- 0 until freq) yield { - new KeyedMessage[String, String](topic, s) - } - messages.toSeq - } - def createTopic(topic: String) { - CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") - logInfo("==================== 5 ====================") - // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) - } - - def produceAndSendMessage(topic: String, sent: Map[String, Int]) { - producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(createTestMessage(topic, sent): _*) - producer.close() - logInfo("==================== 6 ====================") - } +class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { + private var ssc: StreamingContext = _ + private var kafkaTestUtils: KafkaTestUtils = _ - private def getBrokerConfig(): Properties = { - val props = new Properties() - props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkAddress) - props.put("log.flush.interval.messages", "1") - props.put("replica.socket.timeout.ms", "1500") - props + override def beforeAll(): Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } - private def getProducerConfig(): Properties = { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - val props = new Properties() - props.put("metadata.broker.list", brokerAddr) - props.put("serializer.class", classOf[StringEncoder].getName) - props - } - - private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { - assert( - server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) - } - } - - class EmbeddedZookeeper(val zkConnect: String) { - val random = new Random() - val snapshotDir = Utils.createTempDir() - val logDir = Utils.createTempDir() - - val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) - val (ip, port) = { - val splits = zkConnect.split(":") - (splits(0), splits(1).toInt) - } - val factory = new NIOServerCnxnFactory() - factory.configure(new InetSocketAddress(ip, port), 16) - factory.startup(zookeeper) - - val actualPort = factory.getLocalPort - - def shutdown() { - factory.shutdown() - Utils.deleteRecursively(snapshotDir) - Utils.deleteRecursively(logDir) - } - } -} - - -class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { - var ssc: StreamingContext = _ - - before { - setupKafka() - } - - after { + override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("Kafka input stream") { @@ -217,10 +56,10 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkAddress, + val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -234,14 +73,14 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { result.put(kv._1, count) } } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent.size === result.size) sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } } - ssc.stop() } } - diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 64ccc92c81fa..38548dd73b82 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kafka - import java.io.File import scala.collection.mutable @@ -25,61 +24,71 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.google.common.io.Files import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.apache.commons.io.FileUtils -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { +class ReliableKafkaStreamSuite extends FunSuite + with BeforeAndAfterAll with BeforeAndAfter with Eventually { - val sparkConf = new SparkConf() + private val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.receiver.writeAheadLog.enable", "true") - val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private var kafkaTestUtils: KafkaTestUtils = _ - var groupId: String = _ - var kafkaParams: Map[String, String] = _ - var ssc: StreamingContext = _ - var tempDirectory: File = null + private var groupId: String = _ + private var kafkaParams: Map[String, String] = _ + private var ssc: StreamingContext = _ + private var tempDirectory: File = null + + override def beforeAll() : Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() - before { - setupKafka() groupId = s"test-consumer-${Random.nextInt(10000)}" kafkaParams = Map( - "zookeeper.connect" -> zkAddress, + "zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> groupId, "auto.offset.reset" -> "smallest" ) + tempDirectory = Utils.createTempDir() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDirectory) + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + before { ssc = new StreamingContext(sparkConf, Milliseconds(500)) - tempDirectory = Files.createTempDir() ssc.checkpoint(tempDirectory.getAbsolutePath) } after { if (ssc != null) { ssc.stop() + ssc = null } - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } - tearDownKafka() } - test("Reliable Kafka input stream with single topic") { - var topic = "test-topic" - createTopic(topic) - produceAndSendMessage(topic, data) + val topic = "test-topic" + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -95,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter } } ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { // A basic process verification for ReliableKafkaReceiver. // Verify whether received message number is equal to the sent message number. @@ -104,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter // Verify the offset number whether it is equal to the total message number. assert(getCommitOffset(groupId, topic, 0) === Some(29L)) } - ssc.stop() } test("Reliable Kafka input stream with multiple topics") { val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => - createTopic(t) - produceAndSendMessage(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. @@ -122,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) stream.foreachRDD(_ => Unit) ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { // Verify the offset for each group/topic to see whether they are equal to the expected one. topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } } - ssc.stop() } /** Getting partition offset from Zookeeper. */ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { - assert(zkClient != null, "Zookeeper client is not initialized") val topicDirs = new ZKGroupTopicDirs(groupId, topic) val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" - ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong) } } diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 560c8b9d1827..98f95a9a64fa 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 77661f71ada2..3c0ef94cb0fa 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,23 +17,23 @@ package org.apache.spark.streaming.mqtt +import java.io.IOException +import java.util.concurrent.Executors +import java.util.Properties + +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.util.Properties -import java.util.concurrent.Executors -import java.io.IOException - +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage import org.eclipse.paho.client.mqttv3.MqttTopic +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -55,14 +55,14 @@ class MQTTInputDStream( brokerUrl: String, topic: String, storageLevel: StorageLevel - ) extends ReceiverInputDStream[String](ssc_) with Logging { - + ) extends ReceiverInputDStream[String](ssc_) { + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } } -private[streaming] +private[streaming] class MQTTReceiver( brokerUrl: String, topic: String, @@ -72,38 +72,40 @@ class MQTTReceiver( def onStop() { } - + def onStart() { - // Set up persistence for messages + // Set up persistence for messages val persistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) - // Connect to MqttBroker - client.connect() - - // Subscribe to Mqtt topic - client.subscribe(topic) - // Callback automatically triggers as and when new message arrives on specified topic - val callback: MqttCallback = new MqttCallback() { + val callback = new MqttCallback() { // Handles Mqtt message - override def messageArrived(arg0: String, arg1: MqttMessage) { - store(new String(arg1.getPayload(),"utf-8")) + override def messageArrived(topic: String, message: MqttMessage) { + store(new String(message.getPayload(),"utf-8")) } - override def deliveryComplete(arg0: IMqttDeliveryToken) { + override def deliveryComplete(token: IMqttDeliveryToken) { } - override def connectionLost(arg0: Throwable) { - restart("Connection lost ", arg0) + override def connectionLost(cause: Throwable) { + restart("Connection lost ", cause) } } - // Set up callback for MqttClient + // Set up callback for MqttClient. This needs to happen before + // connecting or subscribing, otherwise messages may be lost client.setCallback(callback) + + // Connect to MqttBroker + client.connect() + + // Subscribe to Mqtt topic + client.subscribe(topic) + } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index c5ffe51f9986..1142d0f56ba3 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.mqtt +import scala.reflect.ClassTag + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} object MQTTUtils { diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index fe53a29cba0c..a19a72c58a70 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.streaming.mqtt import java.net.{URI, ServerSocket} +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.activemq.broker.{TransportConnector, BrokerService} +import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence @@ -32,14 +35,16 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.scheduler.StreamingListener +import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.SparkConf import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) - private val master: String = "local[2]" - private val framework: String = this.getClass.getSimpleName + private val master = "local[2]" + private val framework = this.getClass.getSimpleName private val freePort = findFreePort() private val brokerUri = "//localhost:" + freePort private val topic = "def" @@ -65,9 +70,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream: ReceiverInputDStream[String] = + val receiveStream = MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) - var receiveMessage: List[String] = List() + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { receiveMessage = receiveMessage ::: List(rdd.first) @@ -75,6 +80,11 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } } ssc.start() + + // wait for the receiver to start before publishing data, or we risk failing + // the test nondeterministically. See SPARK-4631 + waitForReceiverToStart() + publishData(sendMessage) eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sendMessage.equals(receiveMessage(0))) @@ -84,6 +94,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private def setupMQTT() { broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) connector = new TransportConnector() connector.setName("mqtt") connector.setUri(new URI("mqtt:" + brokerUri)) @@ -103,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) @@ -113,16 +125,23 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { def publishData(data: String): Unit = { var client: MqttClient = null try { - val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) client.connect() if (client.isConnected) { - val msgTopic: MqttTopic = client.getTopic(topic) - val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes("utf-8")) message.setQos(1) message.setRetained(true) - for (i <- 0 to 100) { - msgTopic.publish(message) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } } } } finally { @@ -131,4 +150,18 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { client = null } } + + /** + * Block until at least one receiver has started or timeout occurs. + */ + private def waitForReceiverToStart() = { + val latch = new CountDownLatch(1) + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { + latch.countDown() + } + }) + + assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") + } } diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index da6ffe7662f6..8b6a8959ac4c 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da569..7cf02d85d73d 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088..9a3569789d2e 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e919c2c9b19e..a50d378b3433 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee7..588e6bac7b14 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. */ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 0fb431808bac..4351a8a12fe2 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c815eda52bda..25847a1b33d9 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -67,11 +67,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - com.novocode junit-interface diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b612..6cdc9286c5d7 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 0b80b611cdce..588e86a1887e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kinesis import org.apache.spark.Logging import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.util.SystemClock +import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. @@ -35,7 +33,7 @@ private[kinesis] class KinesisCheckpointState( /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** * Check if it's time to checkpoint based on the current time and the derived time @@ -44,13 +42,13 @@ private[kinesis] class KinesisCheckpointState( * @return true if it's time to checkpoint */ def shouldCheckpoint(): Boolean = { - new SystemClock().currentTime() > checkpointClock.currentTime() + new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() } /** * Advance the checkpoint clock by the checkpoint interval. */ def advanceCheckpoint() = { - checkpointClock.addToTime(checkpointInterval.milliseconds) + checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1bd1f324298e..a7fe4476cacb 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,6 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils import com.amazonaws.auth.AWSCredentialsProvider import com.amazonaws.auth.DefaultAWSCredentialsProviderChain @@ -118,7 +119,7 @@ private[kinesis] class KinesisReceiver( * method. */ override def onStart() { - workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + workerId = Utils.localHostName() + ":" + UUID.randomUUID() credentialsProvider = new DefaultAWSCredentialsProviderChain() kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8ecc2d90160b..af8cd875b454 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -104,7 +104,7 @@ private[kinesis] class KinesisRecordProcessor( logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + s" records for shardId $shardId") logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") } } catch { case e: Throwable => { diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986..edbecdae9209 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 41dbd64c2b1f..255fe6581960 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,17 +20,17 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Milliseconds import org.apache.spark.streaming.Seconds import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.streaming.util.Clock -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.{ManualClock, Clock} + +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException @@ -42,10 +42,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record /** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with EasyMockSugar { + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -73,6 +73,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } + override def afterFunction(): Unit = { + super.afterFunction() + // Since this suite was originally written using EasyMock, add this to preserve the old + // mocking semantics (see SPARK-5735 for more details) + verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, + checkpointStateMock, currentClockMock) + } + test("kinesis utils api") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving @@ -83,193 +91,175 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("process records including store and checkpoint") { - val expectedCheckpointIntervalMillis = 10 - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).once() - receiverMock.store(record2.getData().array()).once() - checkpointStateMock.shouldCheckpoint().andReturn(true).once() - checkpointerMock.checkpoint().once() - checkpointStateMock.advanceCheckpoint().once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) + verify(receiverMock, times(1)).store(record2.getData().array()) + verify(checkpointStateMock, times(1)).shouldCheckpoint() + verify(checkpointerMock, times(1)).checkpoint() + verify(checkpointStateMock, times(1)).advanceCheckpoint() } test("shouldn't store and checkpoint when receiver is stopped") { - expecting { - receiverMock.isStopped().andReturn(true).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() } test("shouldn't checkpoint when exception occurs during store") { - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { + when(currentClockMock.getTimeMillis()).thenReturn(0) + val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - } + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should checkpoint if we have exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).getTimeMillis() } test("should add to time when advancing checkpoint") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) - } + when(currentClockMock.getTimeMillis()).thenReturn(0) + + val checkpointIntervalMillis = 10 + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) + + verify(currentClockMock, times(1)).getTimeMillis() } test("shutdown should checkpoint if the reason is TERMINATE") { - expecting { - checkpointerMock.checkpoint().once() - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + + verify(checkpointerMock, times(1)).checkpoint() } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - expecting { - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(checkpointerMock, never()).checkpoint() } test("retry success on first attempt") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() } test("retry success on second attempt after a Kinesis throttling exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new ThrottlingException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry success on second attempt after a Kinesis dependency exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry failed after a shutdown exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after an invalid state exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after unexpected exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after exhausing all retries") { val expectedErrorMessage = "final try error message" - expecting { - checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) - .andThrow(new ThrottlingException(expectedErrorMessage)).once() - } - whenExecuting(checkpointerMock) { - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() } } diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index d1427f6a0c6e..e14bbae4a9b6 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -42,7 +42,7 @@ - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia diff --git a/graphx/pom.xml b/graphx/pom.xml index 72374aae6da9..d38a3aa8256b 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -41,9 +41,18 @@ ${project.version} - org.jblas - jblas - ${jblas.version} + com.google.guava + guava + + + com.github.fommil.netlib + core + ${netlib.java.version} + + + net.sourceforge.f2j + arpack_combined_all + 0.1 org.scalacheck diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index f70715fca6ee..23430179f12e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -49,3 +49,19 @@ abstract class EdgeContext[VD, ED, A] { et } } + +object EdgeContext { + + /** + * Extractor mainly used for Graph#aggregateMessages*. + * Example: + * {{{ + * val messages = graph.aggregateMessages( + * case ctx @ EdgeContext(_, _, _, _, attr) => + * ctx.sendToDst(attr) + * , _ + _) + * }}} + */ + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = + Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb143977..058c8c8aa1b2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -34,12 +34,12 @@ class EdgeDirection private (private val name: String) extends Serializable { override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. */ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. */ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda4..c8790cac3d8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 84b72b390ca3..36dc7b0f86c8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -55,7 +55,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @return an RDD containing the edges in this graph * * @see [[Edge]] for the edge type. - * @see [[triplets]] to get an RDD which contains all the edges + * @see [[Graph#triplets]] to get an RDD which contains all the edges * along with their vertex data. * */ @@ -104,6 +104,18 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def checkpoint(): Unit + /** + * Return whether this Graph has been checkpointed or not. + * This returns true iff both the vertices RDD and edges RDD have been checkpointed. + */ + def isCheckpointed: Boolean + + /** + * Gets the name of the files to which this Graph was checkpointed. + * (The vertices RDD and edges RDD are checkpointed separately.) + */ + def getCheckpointFiles: Seq[String] + /** * Uncaches both vertices and edges of this graph. This is useful in iterative algorithms that * build a new graph in each iteration. @@ -397,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 4933aecba128..21187be7678a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -77,7 +77,7 @@ object GraphLoader extends Logging { if (!line.isEmpty && line(0) != '#') { val lineArray = line.split("\\s+") if (lineArray.length < 2) { - logWarning("Invalid line: " + line) + throw new IllegalArgumentException("Invalid line: " + line) } val srcId = lineArray(0).toLong val dstId = lineArray(1).toLong diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index dc8b4789c4b6..86f611d55aa8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -113,7 +113,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Collect the neighbor vertex attributes for each vertex. * * @note This function could be highly inefficient on power-law - * graphs where high degree vertices may force a large ammount of + * graphs where high degree vertices may force a large amount of * information to be collected to a single location. * * @param edgeDirection the direction along which to collect @@ -187,7 +187,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali /** * Join the vertices with an RDD and then apply a function from the - * the vertex and RDD entry to a new vertex value. The input table + * vertex and RDD entry to a new vertex value. The input table * should contain at most one entry for each vertex. If no entry is * provided the map function is skipped and the old value is used. * diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 5e55620147df..01b013ff716f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -78,8 +78,8 @@ object Pregel extends Logging { * * @param graph the input graph. * - * @param initialMsg the message each vertex will receive at the on - * the first iteration + * @param initialMsg the message each vertex will receive at the first + * iteration * * @param maxIterations the maximum number of iterations to run for * diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 09ae3f9f6c09..a9f04b559c3d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -122,8 +122,36 @@ abstract class VertexRDD[VD]( def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] /** - * Hides vertices that are the same between `this` and `other`; for vertices that are different, - * keeps the values from `other`. + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other an RDD to run the set operation against + */ + def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other a VertexRDD to run the set operation against + */ + def minus(other: VertexRDD[VD]): VertexRDD[VD] + + /** + * For each vertex present in both `this` and `other`, `diff` returns only those vertices with + * differing values; for values that are different, keeps the values from `other`. This is + * only guaranteed to work if the VertexRDDs share a common ancestor. + * + * @param other the other RDD[(VertexId, VD)] with which to diff against. + */ + def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each vertex present in both `this` and `other`, `diff` returns only those vertices with + * differing values; for values that are different, keeps the values from `other`. This is + * only guaranteed to work if the VertexRDDs share a common ancestor. + * + * @param other the other VertexRDD with which to diff against. */ def diff(other: VertexRDD[VD]): VertexRDD[VD] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af7544837..c56157080925 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a43..c88b2f65a86c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. */ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() @@ -70,10 +70,20 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def checkpoint() = { + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel + + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } - + + override def isCheckpointed: Boolean = { + firstParent[(PartitionID, EdgePartition[ED, VD])].isCheckpointed + } + + override def getCheckpointFile: Option[String] = { + partitionsRDD.getCheckpointFile + } + /** The number of edges in the RDD. */ override def count(): Long = { partitionsRDD.map(_._2.size.toLong).reduce(_ + _) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 3f4a900d5b60..90a74d23a26c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -70,6 +70,17 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( replicatedVertexView.edges.checkpoint() } + override def isCheckpointed: Boolean = { + vertices.isCheckpointed && replicatedVertexView.edges.isCheckpointed + } + + override def getCheckpointFiles: Seq[String] = { + Seq(vertices.getCheckpointFile, replicatedVertexView.edges.getCheckpointFile).flatMap { + case Some(path) => Seq(path) + case None => Seq() + } + } + override def unpersist(blocking: Boolean = true): Graph[VD, ED] = { unpersistVertices(blocking) replicatedVertexView.edges.unpersist(blocking) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038..1df86449fa0c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. */ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 4fd2548b7faf..b90f9fa32705 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -88,6 +88,21 @@ private[graphx] abstract class VertexPartitionBaseOps this.withMask(newMask) } + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Minus operations on two VertexPartitions with different indexes is slow.") + minus(createUsingIndex(other.iterator)) + } else { + self.withMask(self.mask.andNot(other.mask)) + } + } + + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Iterator[(VertexId, VD)]): Self[VD] = { + minus(createUsingIndex(other)) + } + /** * Hides vertices that are the same between this and other. For vertices that are different, keeps * the values from `other`. The indices of `this` and `other` must be the same. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 9732c5b00c6d..33ac7b0ed609 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,10 +71,20 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def checkpoint() = { + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel + + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } - + + override def isCheckpointed: Boolean = { + firstParent[ShippableVertexPartition[VD]].isCheckpointed + } + + override def getCheckpointFile: Option[String] = { + partitionsRDD.getCheckpointFile + } + /** The number of vertices in the RDD. */ override def count(): Long = { partitionsRDD.map(_.size).reduce(_ + _) @@ -93,9 +103,44 @@ class VertexRDDImpl[VD] private[graphx] ( override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = this.mapVertexPartitions(_.map(f)) + override def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + minus(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + + override def minus (other: VertexRDD[VD]): VertexRDD[VD] = { + other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true) { + (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.minus(otherPart)) + }) + case _ => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.minus(msgs)) + } + ) + } + } + + override def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + diff(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val otherPartition = other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + other.partitionsRDD + case _ => + VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD + } val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true + otherPartition, preservesPartitioning = true ) { (thisIter, otherIter) => val thisPart = thisIter.next() val otherPart = otherIter.next() @@ -123,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient leftZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => leftZipJoin(other)(f) case _ => this.withPartitionsRDD[VD3]( @@ -152,7 +197,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient innerZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => innerZipJoin(other)(f) case _ => this.withPartitionsRDD( diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958..859f89603904 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e0651517..2bcf8684b8b8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index e139959c3f5c..042e366a29f5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -25,8 +25,8 @@ import org.apache.spark.graphx._ /** * PageRank algorithm implementation. There are two implementations of PageRank implemented. * - * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number - * of iterations: + * The first implementation uses the standalone [[Graph]] interface and runs PageRank + * for a fixed number of iterations: * {{{ * var PR = Array.fill(n)( 1.0 ) * val oldPR = Array.fill(n)( 1.0 ) @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ * } * }}} * - * The second implementation uses the standalone [[Graph]] interface and runs PageRank until + * The second implementation uses the [[Pregel]] interface and runs PageRank until * convergence: * * {{{ @@ -156,7 +156,7 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index f58587e10a82..3b0e1628d86b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -18,7 +18,9 @@ package org.apache.spark.graphx.lib import scala.util.Random -import org.jblas.DoubleMatrix + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + import org.apache.spark.rdd._ import org.apache.spark.graphx._ @@ -37,12 +39,23 @@ object SVDPlusPlus { var gamma7: Double) extends Serializable + /** + * This method is now replaced by the updated version of `run()` and returns exactly + * the same result. + */ + @deprecated("Call run()", "1.4.0") + def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = + { + run(edges, conf) + } + /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. * - * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), + * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^^-0.5^^*sum(y)), * see the details on page 6. * * @param edges edges for constructing the graph @@ -52,16 +65,13 @@ object SVDPlusPlus { * @return a graph with vertex attributes containing the trained model */ def run(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = { // Generate default vertex attribute - def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = { - val v1 = new DoubleMatrix(rank) - val v2 = new DoubleMatrix(rank) - for (i <- 0 until rank) { - v1.put(i, Random.nextDouble()) - v2.put(i, Random.nextDouble()) - } + def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = { + // TODO: use a fixed random seed + val v1 = Array.fill(rank)(Random.nextDouble()) + val v2 = Array.fill(rank)(Random.nextDouble()) (v1, v2, 0.0, 0.0) } @@ -72,38 +82,47 @@ object SVDPlusPlus { // construct graph var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() + materialize(g) + edges.unpersist() // Calculate initial bias and norm val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), + val gJoinT0 = g.outerJoinVertices(t0) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => - (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) - } + (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1)) + }.cache() + materialize(gJoinT0) + g.unpersist() + g = gJoinT0 def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ - (DoubleMatrix, DoubleMatrix, Double, Double), + (Array[Double], Array[Double], Double, Double), Double, - (DoubleMatrix, DoubleMatrix, Double)]) { + (Array[Double], Array[Double], Double)]) { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dot(usr._2) + val rank = p.length + var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = ctx.attr - pred - val updateP = q.mul(err) - .subColumnVector(p.mul(conf.gamma7)) - .mul(conf.gamma2) - val updateQ = usr._2.mul(err) - .subColumnVector(q.mul(conf.gamma7)) - .mul(conf.gamma2) - val updateY = q.mul(err * usr._4) - .subColumnVector(itm._2.mul(conf.gamma7)) - .mul(conf.gamma2) + // updateP = (err * q - conf.gamma7 * p) * conf.gamma2 + val updateP = q.clone() + blas.dscal(rank, err * conf.gamma2, updateP, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) + // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2 + val updateQ = usr._2.clone() + blas.dscal(rank, err * conf.gamma2, updateQ, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) + // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2 + val updateY = q.clone() + blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) + blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } @@ -111,49 +130,89 @@ object SVDPlusPlus { for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.aggregateMessages[DoubleMatrix]( + val t1 = g.aggregateMessages[Array[Double]]( ctx => ctx.sendToSrc(ctx.dstAttr._2), - (g1, g2) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), - msg: Option[DoubleMatrix]) => - if (msg.isDefined) (vd._1, vd._1 - .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd - } + (g1, g2) => { + val out = g1.clone() + blas.daxpy(out.length, 1.0, g2, 1, out, 1) + out + }) + val gJoinT1 = g.outerJoinVertices(t1) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), + msg: Option[Array[Double]]) => + if (msg.isDefined) { + val out = vd._1.clone() + blas.daxpy(out.length, vd._4, msg.get, 1, out, 1) + (vd._1, out, vd._3, vd._4) + } else { + vd + } + }.cache() + materialize(gJoinT1) + g.unpersist() + g = gJoinT1 // Phase 2, update p for user nodes and q, y for item nodes g.cache() val t2 = g.aggregateMessages( sendMsgTrainF(conf, u), - (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => - (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => + { + val out1 = g1._1.clone() + blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) + val out2 = g2._2.clone() + blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) + (out1, out2, g1._3 + g2._3) + }) + val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, - vd: (DoubleMatrix, DoubleMatrix, Double, Double), - msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => - (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), - vd._3 + msg.get._3, vd._4) - } + vd: (Array[Double], Array[Double], Double, Double), + msg: Option[(Array[Double], Array[Double], Double)]) => { + val out1 = vd._1.clone() + blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) + val out2 = vd._2.clone() + blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) + (out1, out2, vd._3 + msg.get._3, vd._4) + } + }.cache() + materialize(gJoinT2) + g.unpersist() + g = gJoinT2 } // calculate error on training set def sendMsgTestF(conf: Conf, u: Double) - (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) { + (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]) { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + q.dot(usr._2) + var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = (ctx.attr - pred) * (ctx.attr - pred) ctx.sendToDst(err) } + g.cache() val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) - g = g.outerJoinVertices(t3) { - (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => + val gJoinT3 = g.outerJoinVertices(t3) { + (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd - } + }.cache() + materialize(gJoinT3) + g.unpersist() + g = gJoinT3 - (g, u) + // Convert DoubleMatrix to Array[Double]: + val newVertices = g.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4)) + (Graph(newVertices, g.edges), u) } + + /** + * Forces materialization of a Graph by count()ing its RDDs. + */ + private def materialize(g: Graph[_,_]): Unit = { + g.vertices.count() + g.edges.count() + } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 590f0474957d..179f2843818e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -61,8 +61,8 @@ object ShortestPaths { } def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = { - val newAttr = incrementMap(edge.srcAttr) - if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr)) + val newAttr = incrementMap(edge.dstAttr) + if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr)) else Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1f..e2754ea699da 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala new file mode 100644 index 000000000000..eb1dbe52c2fd --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import org.scalatest.FunSuite + +import org.apache.spark.storage.StorageLevel + +class EdgeRDDSuite extends FunSuite with LocalSparkContext { + + test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value after caching + withSpark { sc => + val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3))) + val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) + assert(edges.getStorageLevel == StorageLevel.NONE) + edges.cache() + assert(edges.getStorageLevel == StorageLevel.MEMORY_ONLY) + } + } + +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb..a570e4ed75fc 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.graphx import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils class GraphSuite extends FunSuite with LocalSparkContext { @@ -38,12 +38,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { val doubleRing = ring ++ ring val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1) assert(graph.edges.count() === doubleRing.size) - assert(graph.edges.collect.forall(e => e.attr == 1)) + assert(graph.edges.collect().forall(e => e.attr == 1)) // uniqueEdges option should uniquify edges and store duplicate count in edge attributes val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut)) assert(uniqueGraph.edges.count() === ring.size) - assert(uniqueGraph.edges.collect.forall(e => e.attr == 2)) + assert(uniqueGraph.edges.collect().forall(e => e.attr == 2)) } } @@ -64,7 +64,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect().map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } @@ -75,15 +75,17 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet === - (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) + assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet + === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) } } test("partitionBy") { withSpark { sc => - def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) - def nonemptyParts(graph: Graph[Int, Int]) = { + def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = { + Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) + } + def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = { graph.edges.partitionsRDD.mapPartitions { iter => Iterator(iter.next()._2.iterator.toList) }.filter(_.nonEmpty) @@ -102,7 +104,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1) // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into // the same partition - assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) + assert( + nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) // partitionBy(EdgePartition2D) puts identical edges in the same partition assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1) @@ -140,10 +143,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val g = Graph( sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))), sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2)) - assert(g.triplets.collect.map(_.toTuple).toSet === + assert(g.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) val gPart = g.partitionBy(EdgePartition2D) - assert(gPart.triplets.collect.map(_.toTuple).toSet === + assert(gPart.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) } } @@ -154,10 +157,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = starGraph(sc, n) // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length) - assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) + assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -177,12 +180,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Trigger initial vertex replication graph0.triplets.foreach(x => {}) // Change type of replicated vertices, but preserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + val graph1 = graph0.mapVertices { case (vid, integerOpt) => + integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double) } // Access replicated vertices, exposing the erased type val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) } } @@ -202,7 +205,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet === + assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet === (1L to n).map(x => Edge(0, x, "vv")).toSet) } } @@ -211,7 +214,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) + assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -221,7 +224,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect.toSet === Set((1L, 2))) + assert(result.collect().toSet === Set((1L, 2))) } } @@ -237,7 +240,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet) // And 4 edges. - assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet) + assert(subgraph.edges.map(_.copy()).collect().toSet === + (2 to n by 2).map(x => Edge(0, x, 1)).toSet) } } @@ -273,9 +277,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 to n).flatMap(x => List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v") val star2 = doubleStar.groupEdges { (a, b) => a} - assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) === - star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int])) - assert(star2.vertices.collect.toSet === star.vertices.collect.toSet) + assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) === + star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int])) + assert(star2.vertices.collect().toSet === star.vertices.collect().toSet) } } @@ -300,21 +304,23 @@ class GraphSuite extends FunSuite with LocalSparkContext { throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) } Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet + }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3) + val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } + val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => + newOpt.getOrElse(old) + } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set if (et.srcId % 2 != 1) { throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) } Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet + }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) } @@ -340,17 +346,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = - reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } + val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) + } val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect.toSet + (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === + assert(newReverseStar.vertices.map(_._2).collect().toSet === (0 to n).map(x => "v%d".format(x)).toSet) } } @@ -361,20 +368,21 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) val graph = Graph(verts, edges) val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)) - .collect.toSet + .collect().toSet assert(triplets === Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a"))) } } test("checkpoint") { - val checkpointDir = Files.createTempDir() - checkpointDir.deleteOnExit() + val checkpointDir = Utils.createTempDir() withSpark { sc => sc.setCheckpointDir(checkpointDir.getAbsolutePath) val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)} val rdd = sc.parallelize(ring) val graph = Graph.fromEdges(rdd, 1.0F) + assert(!graph.isCheckpointed) + assert(graph.getCheckpointFiles.size === 0) graph.checkpoint() graph.edges.map(_.attr).count() graph.vertices.map(_._2).count() @@ -383,6 +391,42 @@ class GraphSuite extends FunSuite with LocalSparkContext { val verticesDependencies = graph.vertices.partitionsRDD.dependencies assert(edgesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]])) assert(verticesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]])) + assert(graph.isCheckpointed) + assert(graph.getCheckpointFiles.size === 2) + } + } + + test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value + withSpark { sc => + val verts = sc.parallelize(List((1: VertexId, "a"), (2: VertexId, "b")), 1) + val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) + val graph = Graph(verts, edges, "", StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY) + // Note: Before caching, graph.vertices is cached, but graph.edges is not (but graph.edges' + // parent RDD is cached). + graph.cache() + assert(graph.vertices.getStorageLevel == StorageLevel.MEMORY_ONLY) + assert(graph.edges.getStorageLevel == StorageLevel.MEMORY_ONLY) + } + } + + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) + } finally { + sc.stop() } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index a3e28efc75a9..d2ad9be55577 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ - def withSpark[T](f: SparkContext => T) = { + def withSpark[T](f: SparkContext => T): T = { val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 42d3f21dbae9..d0a7198d691d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext -import org.apache.spark.graphx.Graph._ -import org.apache.spark.graphx.impl.EdgePartition -import org.apache.spark.rdd._ import org.scalatest.FunSuite +import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + class VertexRDDSuite extends FunSuite with LocalSparkContext { - def vertices(sc: SparkContext, n: Int) = { + private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) } @@ -47,6 +47,35 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("minus") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) + } + } + + test("minus with RDD[(VertexId, VD)]") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB: RDD[(VertexId, Int)] = + sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) + } + } + + test("minus with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 5).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet) + } + } + test("diff") { withSpark { sc => val n = 100 @@ -59,42 +88,90 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("diff with RDD[(VertexId, VD)]") { + withSpark { sc => + val n = 100 + val verts = vertices(sc, n).cache() + val flipEvens: RDD[(VertexId, Int)] = + sc.parallelize(0L to 100L) + .map(id => if (id % 2 == 0) (id, -id.toInt) else (id, id.toInt)).cache() + // diff should keep only the changed vertices + assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet) + } + } + + test("diff vertices with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 24, 3).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.diff(vertexB) + assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet) + } + } + test("leftJoin") { withSpark { sc => val n = 100 val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD - assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) // leftJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) } } + test("leftJoin vertices with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) + val vertexB = VertexRDD( + vertexA.filter(v => v._1 % 2 == 0).partitionBy(new HashPartitioner(3))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) => + old - newOpt.getOrElse(0) + } + assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet) + } + } + test("innerJoin") { withSpark { sc => val n = 100 val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD - assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) // innerJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) } } + test("innerJoin vertices with the non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) + val vertexB = VertexRDD( + vertexA.filter(v => v._1 % 2 == 0).partitionBy(new HashPartitioner(3))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) => + old - newVal + } + assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet) + } + } + test("aggregateUsingIndex") { withSpark { sc => val n = 100 val verts = vertices(sc, n) val messageTargets = (0 to n) ++ (0 to n by 2) val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1))) - assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet === + assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet === (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet) } } @@ -106,7 +183,19 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) // test merge function - assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9))) + } + } + + test("cache, getStorageLevel") { + // test to see if getStorageLevel returns correct value after caching + withSpark { sc => + val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3))) + val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) + val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) + assert(rdd.getStorageLevel == StorageLevel.NONE) + rdd.cache() + assert(rdd.getStorageLevel == StorageLevel.MEMORY_ONLY) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 3915be15b343..4cc30a96408f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -32,7 +32,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -42,7 +42,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -50,8 +50,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() @@ -73,12 +73,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if (id < 10) { assert(cc === 0) @@ -120,9 +120,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { // Build the initial Graph val graph = Graph(users, relationships, defaultUser) val ccGraph = graph.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - assert(cc == 0) + assert(cc === 0) } } } // end of toy connected components diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index fc491ae327c2..95804b07b1db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -19,15 +19,12 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ + object GridPageRank { - def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = { + def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = { val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int]) val outDegree = Array.fill(nRows * nCols)(0) // Convert row column address into vertex ids (row major order) @@ -35,13 +32,13 @@ object GridPageRank { // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { val ind = sub2ind(r,c) - if (r+1 < nRows) { + if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r+1,c)) += ind + inNbrs(sub2ind(r + 1,c)) += ind } - if (c+1 < nCols) { + if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c+1)) += ind + inNbrs(sub2ind(r,c + 1)) += ind } } // compute the pagerank @@ -64,7 +61,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } - .map { case (id, error) => error }.sum + .map { case (id, error) => error }.sum() } test("Star PageRank") { @@ -80,12 +77,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext { // Static PageRank should only take 2 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum + }.map { case (vid, test) => test }.sum() assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) + val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) + val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) @@ -95,8 +92,6 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank - - test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -109,18 +104,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() + val referenceRanks = VertexRDD( + sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) } } // end of Grid PageRank - test("Chain PageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index e01df56e94de..7bd6b7f3c4ab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -32,11 +32,11 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = SVDPlusPlus.run(edges, conf) + val (graph, _) = SVDPlusPlus.run(edges, conf) graph.cache() - val err = graph.vertices.collect().map{ case (vid, vd) => + val err = graph.vertices.map { case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 - }.reduce(_ + _) / graph.triplets.collect().size + }.reduce(_ + _) / graph.numEdges assert(err <= svdppErr) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index 265827b3341c..f2c38e79c452 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -40,7 +40,7 @@ class ShortestPathsSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val landmarks = Seq(1, 4).map(_.toLong) val results = ShortestPaths.run(graph, landmarks).vertices.collect.map { - case (v, spMap) => (v, spMap.mapValues(_.get)) + case (v, spMap) => (v, spMap.mapValues(i => i)) } assert(results.toSet === shortestPaths) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index df54aa37cad6..1f658c371ffc 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -34,8 +34,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(Seq.empty[Edge[Int]]) val graph = Graph(vertices, edges) val sccGraph = graph.stronglyConnectedComponents(5) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(id === scc) } } } @@ -45,8 +45,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7))) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(0L == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(0L === scc) } } } @@ -60,13 +60,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize(edges) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - if (id < 3) - assert(0L == scc) - else if (id < 6) - assert(3L == scc) - else - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + if (id < 3) { + assert(0L === scc) + } else if (id < 6) { + assert(3L === scc) + } else { + assert(id === scc) + } } } } diff --git a/launcher/pom.xml b/launcher/pom.xml new file mode 100644 index 000000000000..ebfa7685eaa1 --- /dev/null +++ b/launcher/pom.xml @@ -0,0 +1,84 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.4.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-launcher_2.10 + jar + Spark Launcher Project + http://spark.apache.org/ + + launcher + + + + + + log4j + log4j + test + + + junit + junit + test + + + org.mockito + mockito-all + test + + + org.slf4j + slf4j-api + test + + + org.slf4j + slf4j-log4j12 + test + + + + + org.apache.hadoop + hadoop-client + test + + + org.codehaus.jackson + jackson-mapper-asl + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java new file mode 100644 index 000000000000..b8f02b961113 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileFilter; +import java.io.FileInputStream; +import java.io.InputStreamReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.jar.JarFile; +import java.util.regex.Pattern; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Abstract Spark command builder that defines common functionality. + */ +abstract class AbstractCommandBuilder { + + boolean verbose; + String appName; + String appResource; + String deployMode; + String javaHome; + String mainClass; + String master; + String propertiesFile; + final List appArgs; + final List jars; + final List files; + final List pyFiles; + final Map childEnv; + final Map conf; + + public AbstractCommandBuilder() { + this.appArgs = new ArrayList(); + this.childEnv = new HashMap(); + this.conf = new HashMap(); + this.files = new ArrayList(); + this.jars = new ArrayList(); + this.pyFiles = new ArrayList(); + } + + /** + * Builds the command to execute. + * + * @param env A map containing environment variables for the child process. It may already contain + * entries defined by the user (such as SPARK_HOME, or those defined by the + * SparkLauncher constructor that takes an environment), and may be modified to + * include other variables needed by the process to be executed. + */ + abstract List buildCommand(Map env) throws IOException; + + /** + * Builds a list of arguments to run java. + * + * This method finds the java executable to use and appends JVM-specific options for running a + * class with Spark in the classpath. It also loads options from the "java-opts" file in the + * configuration directory being used. + * + * Callers should still add at least the class to run, as well as any arguments to pass to the + * class. + */ + List buildJavaCommand(String extraClassPath) throws IOException { + List cmd = new ArrayList(); + String envJavaHome; + + if (javaHome != null) { + cmd.add(join(File.separator, javaHome, "bin", "java")); + } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) { + cmd.add(join(File.separator, envJavaHome, "bin", "java")); + } else { + cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java")); + } + + // Load extra JAVA_OPTS from conf/java-opts, if it exists. + File javaOpts = new File(join(File.separator, getConfDir(), "java-opts")); + if (javaOpts.isFile()) { + BufferedReader br = new BufferedReader(new InputStreamReader( + new FileInputStream(javaOpts), "UTF-8")); + try { + String line; + while ((line = br.readLine()) != null) { + addOptionString(cmd, line); + } + } finally { + br.close(); + } + } + + cmd.add("-cp"); + cmd.add(join(File.pathSeparator, buildClassPath(extraClassPath))); + return cmd; + } + + /** + * Adds the default perm gen size option for Spark if the VM requires it and the user hasn't + * set it. + */ + void addPermGenSizeOpt(List cmd) { + // Don't set MaxPermSize for Java 8 and later. + String[] version = System.getProperty("java.version").split("\\."); + if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { + return; + } + + for (String arg : cmd) { + if (arg.startsWith("-XX:MaxPermSize=")) { + return; + } + } + + cmd.add("-XX:MaxPermSize=128m"); + } + + void addOptionString(List cmd, String options) { + if (!isEmpty(options)) { + for (String opt : parseOptionString(options)) { + cmd.add(opt); + } + } + } + + /** + * Builds the classpath for the application. Returns a list with one classpath entry per element; + * each entry is formatted in the way expected by java.net.URLClassLoader (more + * specifically, with trailing slashes for directories). + */ + List buildClassPath(String appClassPath) throws IOException { + String sparkHome = getSparkHome(); + + List cp = new ArrayList(); + addToClassPath(cp, getenv("SPARK_CLASSPATH")); + addToClassPath(cp, appClassPath); + + addToClassPath(cp, getConfDir()); + + boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); + boolean isTesting = "1".equals(getenv("SPARK_TESTING")); + if (prependClasses || isTesting) { + String scala = getScalaVersion(); + List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", + "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", + "yarn", "launcher"); + if (prependClasses) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + for (String project : projects) { + addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, + scala)); + } + } + if (isTesting) { + for (String project : projects) { + addToClassPath(cp, String.format("%s/%s/target/scala-%s/test-classes", sparkHome, + project, scala)); + } + } + + // Add this path to include jars that are shaded in the final deliverable created during + // the maven build. These jars are copied to this directory during the build. + addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); + } + + // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as + // when running unit tests, or user code that embeds Spark and creates a SparkContext + // with a local or local-cluster master, will cause this code to be called from an + // environment where that env variable is not guaranteed to exist. + // + // For the testing case, we rely on the test code to set and propagate the test classpath + // appropriately. + // + // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. + // That duplicates some of the code in the shell scripts that look for the assembly, though. + String assembly = getenv(ENV_SPARK_ASSEMBLY); + if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + assembly = findAssembly(); + } + addToClassPath(cp, assembly); + + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + } else { + libdir = new File(sparkHome, "lib_managed/jars"); + } + + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } + } + + addToClassPath(cp, getenv("HADOOP_CONF_DIR")); + addToClassPath(cp, getenv("YARN_CONF_DIR")); + addToClassPath(cp, getenv("SPARK_DIST_CLASSPATH")); + return cp; + } + + /** + * Adds entries to the classpath. + * + * @param cp List to which the new entries are appended. + * @param entries New classpath entries (separated by File.pathSeparator). + */ + private void addToClassPath(List cp, String entries) { + if (isEmpty(entries)) { + return; + } + String[] split = entries.split(Pattern.quote(File.pathSeparator)); + for (String entry : split) { + if (!isEmpty(entry)) { + if (new File(entry).isDirectory() && !entry.endsWith(File.separator)) { + entry += File.separator; + } + cp.add(entry); + } + } + } + + String getScalaVersion() { + String scala = getenv("SPARK_SCALA_VERSION"); + if (scala != null) { + return scala; + } + String sparkHome = getSparkHome(); + File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); + File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + checkState(!scala210.isDirectory() || !scala211.isDirectory(), + "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); + if (scala210.isDirectory()) { + return "2.10"; + } else { + checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + return "2.11"; + } + } + + String getSparkHome() { + String path = getenv(ENV_SPARK_HOME); + checkState(path != null, + "Spark home not found; set it explicitly or use the SPARK_HOME environment variable."); + return path; + } + + /** + * Loads the configuration file for the application, if it exists. This is either the + * user-specified properties file, or the spark-defaults.conf file under the Spark configuration + * directory. + */ + Properties loadPropertiesFile() throws IOException { + Properties props = new Properties(); + File propsFile; + if (propertiesFile != null) { + propsFile = new File(propertiesFile); + checkArgument(propsFile.isFile(), "Invalid properties file '%s'.", propertiesFile); + } else { + propsFile = new File(getConfDir(), DEFAULT_PROPERTIES_FILE); + } + + if (propsFile.isFile()) { + FileInputStream fd = null; + try { + fd = new FileInputStream(propsFile); + props.load(new InputStreamReader(fd, "UTF-8")); + } finally { + if (fd != null) { + try { + fd.close(); + } catch (IOException e) { + // Ignore. + } + } + } + } + + return props; + } + + String getenv(String key) { + return firstNonEmpty(childEnv.get(key), System.getenv(key)); + } + + private String findAssembly() { + String sparkHome = getSparkHome(); + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); + } + + final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); + FileFilter filter = new FileFilter() { + @Override + public boolean accept(File file) { + return file.isFile() && re.matcher(file.getName()).matches(); + } + }; + File[] assemblies = libdir.listFiles(filter); + checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); + checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); + return assemblies[0].getAbsolutePath(); + } + + private String getConfDir() { + String confDir = getenv("SPARK_CONF_DIR"); + return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java new file mode 100644 index 000000000000..8028e42ffb48 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Helper methods for command builders. + */ +class CommandBuilderUtils { + + static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; + static final String ENV_SPARK_HOME = "SPARK_HOME"; + static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; + + /** Returns whether the given string is null or empty. */ + static boolean isEmpty(String s) { + return s == null || s.isEmpty(); + } + + /** Joins a list of strings using the given separator. */ + static String join(String sep, String... elements) { + StringBuilder sb = new StringBuilder(); + for (String e : elements) { + if (e != null) { + if (sb.length() > 0) { + sb.append(sep); + } + sb.append(e); + } + } + return sb.toString(); + } + + /** Joins a list of strings using the given separator. */ + static String join(String sep, Iterable elements) { + StringBuilder sb = new StringBuilder(); + for (String e : elements) { + if (e != null) { + if (sb.length() > 0) { + sb.append(sep); + } + sb.append(e); + } + } + return sb.toString(); + } + + /** + * Returns the first non-empty value mapped to the given key in the given maps, or null otherwise. + */ + static String firstNonEmptyValue(String key, Map... maps) { + for (Map map : maps) { + String value = (String) map.get(key); + if (!isEmpty(value)) { + return value; + } + } + return null; + } + + /** Returns the first non-empty, non-null string in the given list, or null otherwise. */ + static String firstNonEmpty(String... candidates) { + for (String s : candidates) { + if (!isEmpty(s)) { + return s; + } + } + return null; + } + + /** Returns the name of the env variable that holds the native library path. */ + static String getLibPathEnvName() { + if (isWindows()) { + return "PATH"; + } + + String os = System.getProperty("os.name"); + if (os.startsWith("Mac OS X")) { + return "DYLD_LIBRARY_PATH"; + } else { + return "LD_LIBRARY_PATH"; + } + } + + /** Returns whether the OS is Windows. */ + static boolean isWindows() { + String os = System.getProperty("os.name"); + return os.startsWith("Windows"); + } + + /** + * Updates the user environment, appending the given pathList to the existing value of the given + * environment variable (or setting it if it hasn't yet been set). + */ + static void mergeEnvPathList(Map userEnv, String envKey, String pathList) { + if (!isEmpty(pathList)) { + String current = firstNonEmpty(userEnv.get(envKey), System.getenv(envKey)); + userEnv.put(envKey, join(File.pathSeparator, current, pathList)); + } + } + + /** + * Parse a string as if it were a list of arguments, following bash semantics. + * For example: + * + * Input: "\"ab cd\" efgh 'i \" j'" + * Output: [ "ab cd", "efgh", "i \" j" ] + */ + static List parseOptionString(String s) { + List opts = new ArrayList(); + StringBuilder opt = new StringBuilder(); + boolean inOpt = false; + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + boolean escapeNext = false; + + // This is needed to detect when a quoted empty string is used as an argument ("" or ''). + boolean hasData = false; + + for (int i = 0; i < s.length(); i++) { + int c = s.codePointAt(i); + if (escapeNext) { + opt.appendCodePoint(c); + escapeNext = false; + } else if (inOpt) { + switch (c) { + case '\\': + if (inSingleQuote) { + opt.appendCodePoint(c); + } else { + escapeNext = true; + } + break; + case '\'': + if (inDoubleQuote) { + opt.appendCodePoint(c); + } else { + inSingleQuote = !inSingleQuote; + } + break; + case '"': + if (inSingleQuote) { + opt.appendCodePoint(c); + } else { + inDoubleQuote = !inDoubleQuote; + } + break; + default: + if (!Character.isWhitespace(c) || inSingleQuote || inDoubleQuote) { + opt.appendCodePoint(c); + } else { + opts.add(opt.toString()); + opt.setLength(0); + inOpt = false; + hasData = false; + } + } + } else { + switch (c) { + case '\'': + inSingleQuote = true; + inOpt = true; + hasData = true; + break; + case '"': + inDoubleQuote = true; + inOpt = true; + hasData = true; + break; + case '\\': + escapeNext = true; + inOpt = true; + hasData = true; + break; + default: + if (!Character.isWhitespace(c)) { + inOpt = true; + hasData = true; + opt.appendCodePoint(c); + } + } + } + } + + checkArgument(!inSingleQuote && !inDoubleQuote && !escapeNext, "Invalid option string: %s", s); + if (hasData) { + opts.add(opt.toString()); + } + return opts; + } + + /** Throws IllegalArgumentException if the given object is null. */ + static void checkNotNull(Object o, String arg) { + if (o == null) { + throw new IllegalArgumentException(String.format("'%s' must not be null.", arg)); + } + } + + /** Throws IllegalArgumentException with the given message if the check is false. */ + static void checkArgument(boolean check, String msg, Object... args) { + if (!check) { + throw new IllegalArgumentException(String.format(msg, args)); + } + } + + /** Throws IllegalStateException with the given message if the check is false. */ + static void checkState(boolean check, String msg, Object... args) { + if (!check) { + throw new IllegalStateException(String.format(msg, args)); + } + } + + /** + * Quote a command argument for a command to be run by a Windows batch script, if the argument + * needs quoting. Arguments only seem to need quotes in batch scripts if they have certain + * special characters, some of which need extra (and different) escaping. + * + * For example: + * original single argument: ab="cde fgh" + * quoted: "ab^=""cde fgh""" + */ + static String quoteForBatchScript(String arg) { + + boolean needsQuotes = false; + for (int i = 0; i < arg.length(); i++) { + int c = arg.codePointAt(i); + if (Character.isWhitespace(c) || c == '"' || c == '=') { + needsQuotes = true; + break; + } + } + if (!needsQuotes) { + return arg; + } + StringBuilder quoted = new StringBuilder(); + quoted.append("\""); + for (int i = 0; i < arg.length(); i++) { + int cp = arg.codePointAt(i); + switch (cp) { + case '"': + quoted.append('"'); + break; + + case '=': + quoted.append('^'); + break; + + default: + break; + } + quoted.appendCodePoint(cp); + } + quoted.append("\""); + return quoted.toString(); + } + + /** + * Quotes a string so that it can be used in a command string. + * Basically, just add simple escapes. E.g.: + * original single argument : ab "cd" ef + * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. + */ + static String quoteForCommandString(String s) { + StringBuilder quoted = new StringBuilder().append('"'); + for (int i = 0; i < s.length(); i++) { + int cp = s.codePointAt(i); + if (cp == '"' || cp == '\\') { + quoted.appendCodePoint('\\'); + } + quoted.appendCodePoint(cp); + } + return quoted.append('"').toString(); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java new file mode 100644 index 000000000000..206acfb514d8 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Command line interface for the Spark launcher. Used internally by Spark scripts. + */ +class Main { + + /** + * Usage: Main [class] [class args] + *

    + * This CLI works in two different modes: + *

      + *
    • "spark-submit": if class is "org.apache.spark.deploy.SparkSubmit", the + * {@link SparkLauncher} class is used to launch a Spark application.
    • + *
    • "spark-class": if another class is provided, an internal Spark class is run.
    • + *
    + * + * This class works in tandem with the "bin/spark-class" script on Unix-like systems, and + * "bin/spark-class2.cmd" batch script on Windows to execute the final command. + *

    + * On Unix-like systems, the output is a list of command arguments, separated by the NULL + * character. On Windows, the output is a command line suitable for direct execution from the + * script. + */ + public static void main(String[] argsArray) throws Exception { + checkArgument(argsArray.length > 0, "Not enough arguments: missing class name."); + + List args = new ArrayList(Arrays.asList(argsArray)); + String className = args.remove(0); + + boolean printLaunchCommand; + boolean printUsage; + AbstractCommandBuilder builder; + try { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + builder = new SparkSubmitCommandBuilder(args); + } else { + builder = new SparkClassCommandBuilder(className, args); + } + printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); + printUsage = false; + } catch (IllegalArgumentException e) { + builder = new UsageCommandBuilder(e.getMessage()); + printLaunchCommand = false; + printUsage = true; + } + + Map env = new HashMap(); + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + + if (isWindows()) { + // When printing the usage message, we can't use "cmd /v" since that prevents the env + // variable from being seen in the caller script. So do not call prepareWindowsCommand(). + if (printUsage) { + System.out.println(join(" ", cmd)); + } else { + System.out.println(prepareWindowsCommand(cmd, env)); + } + } else { + // In bash, use NULL as the arg separator since it cannot be used in an argument. + List bashCmd = prepareBashCommand(cmd, env); + for (String c : bashCmd) { + System.out.print(c); + System.out.print('\0'); + } + } + } + + /** + * Prepare a command line for execution from a Windows batch script. + * + * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments + * are "double quoted" (which is batch for escaping a quote). This page has more details about + * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html + * + * The command is executed using "cmd /c" and formatted in single line, since that's the + * easiest way to consume this from a batch script (see spark-class2.cmd). + */ + private static String prepareWindowsCommand(List cmd, Map childEnv) { + StringBuilder cmdline = new StringBuilder("cmd /c \""); + for (Map.Entry e : childEnv.entrySet()) { + cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); + cmdline.append(" && "); + } + for (String arg : cmd) { + cmdline.append(quoteForBatchScript(arg)); + cmdline.append(" "); + } + cmdline.append("\""); + return cmdline.toString(); + } + + /** + * Prepare the command for execution from a bash script. The final command will have commands to + * set up any needed environment variables needed by the child process. + */ + private static List prepareBashCommand(List cmd, Map childEnv) { + if (childEnv.isEmpty()) { + return cmd; + } + + List newCmd = new ArrayList(); + newCmd.add("env"); + + for (Map.Entry e : childEnv.entrySet()) { + newCmd.add(String.format("%s=%s", e.getKey(), e.getValue())); + } + newCmd.addAll(cmd); + return newCmd; + } + + /** + * Internal builder used when command line parsing fails. This will behave differently depending + * on the platform: + * + * - On Unix-like systems, it will print a call to the "usage" function with two arguments: the + * the error string, and the exit code to use. The function is expected to print the command's + * usage and exit with the provided exit code. The script should use "export -f usage" after + * declaring a function called "usage", so that the function is available to downstream scripts. + * + * - On Windows it will set the variable "SPARK_LAUNCHER_USAGE_ERROR" to the usage error message. + * The batch script should check for this variable and print its usage, since batch scripts + * don't really support the "export -f" functionality used in bash. + */ + private static class UsageCommandBuilder extends AbstractCommandBuilder { + + private final String message; + + UsageCommandBuilder(String message) { + this.message = message; + } + + @Override + public List buildCommand(Map env) { + if (isWindows()) { + return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); + } else { + return Arrays.asList("usage", message, "1"); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java new file mode 100644 index 000000000000..e601a0a19f36 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Command builder for internal Spark classes. + *

    + * This class handles building the command to launch all internal Spark classes except for + * SparkSubmit (which is handled by {@link SparkSubmitCommandBuilder} class. + */ +class SparkClassCommandBuilder extends AbstractCommandBuilder { + + private final String className; + private final List classArgs; + + SparkClassCommandBuilder(String className, List classArgs) { + this.className = className; + this.classArgs = classArgs; + } + + @Override + public List buildCommand(Map env) throws IOException { + List javaOptsKeys = new ArrayList(); + String memKey = null; + String extraClassPath = null; + + // Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + + // SPARK_DAEMON_MEMORY. + if (className.equals("org.apache.spark.deploy.master.Master")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_MASTER_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.worker.Worker")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_WORKER_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.history.HistoryServer")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_HISTORY_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; + } else if (className.equals("org.apache.spark.executor.CoarseGrainedExecutorBackend")) { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); + memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { + javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); + memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.startsWith("org.apache.spark.tools.")) { + String sparkHome = getSparkHome(); + File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", + "scala-" + getScalaVersion())); + checkState(toolsDir.isDirectory(), "Cannot find tools build directory."); + + Pattern re = Pattern.compile("spark-tools_.*\\.jar"); + for (File f : toolsDir.listFiles()) { + if (re.matcher(f.getName()).matches()) { + extraClassPath = f.getAbsolutePath(); + break; + } + } + + checkState(extraClassPath != null, + "Failed to find Spark Tools Jar in %s.\n" + + "You need to run \"build/sbt tools/package\" before running %s.", + toolsDir.getAbsolutePath(), className); + + javaOptsKeys.add("SPARK_JAVA_OPTS"); + } + + List cmd = buildJavaCommand(extraClassPath); + for (String key : javaOptsKeys) { + addOptionString(cmd, System.getenv(key)); + } + + String mem = firstNonEmpty(memKey != null ? System.getenv(memKey) : null, DEFAULT_MEM); + cmd.add("-Xms" + mem); + cmd.add("-Xmx" + mem); + addPermGenSizeOpt(cmd); + cmd.add(className); + cmd.addAll(classArgs); + return cmd; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java new file mode 100644 index 000000000000..d4cfeacb6ef1 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Launcher for Spark applications. + *

    + * Use this class to start Spark applications programmatically. The class uses a builder pattern + * to allow clients to configure the Spark application and launch it as a child process. + */ +public class SparkLauncher { + + /** The Spark master. */ + public static final String SPARK_MASTER = "spark.master"; + + /** Configuration key for the driver memory. */ + public static final String DRIVER_MEMORY = "spark.driver.memory"; + /** Configuration key for the driver class path. */ + public static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath"; + /** Configuration key for the driver VM options. */ + public static final String DRIVER_EXTRA_JAVA_OPTIONS = "spark.driver.extraJavaOptions"; + /** Configuration key for the driver native library path. */ + public static final String DRIVER_EXTRA_LIBRARY_PATH = "spark.driver.extraLibraryPath"; + + /** Configuration key for the executor memory. */ + public static final String EXECUTOR_MEMORY = "spark.executor.memory"; + /** Configuration key for the executor class path. */ + public static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath"; + /** Configuration key for the executor VM options. */ + public static final String EXECUTOR_EXTRA_JAVA_OPTIONS = "spark.executor.extraJavaOptions"; + /** Configuration key for the executor native library path. */ + public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryPath"; + /** Configuration key for the number of executor CPU cores. */ + public static final String EXECUTOR_CORES = "spark.executor.cores"; + + private final SparkSubmitCommandBuilder builder; + + public SparkLauncher() { + this(null); + } + + /** + * Creates a launcher that will set the given environment variables in the child. + * + * @param env Environment variables to set. + */ + public SparkLauncher(Map env) { + this.builder = new SparkSubmitCommandBuilder(); + if (env != null) { + this.builder.childEnv.putAll(env); + } + } + + /** + * Set a custom JAVA_HOME for launching the Spark application. + * + * @param javaHome Path to the JAVA_HOME to use. + * @return This launcher. + */ + public SparkLauncher setJavaHome(String javaHome) { + checkNotNull(javaHome, "javaHome"); + builder.javaHome = javaHome; + return this; + } + + /** + * Set a custom Spark installation location for the application. + * + * @param sparkHome Path to the Spark installation to use. + * @return This launcher. + */ + public SparkLauncher setSparkHome(String sparkHome) { + checkNotNull(sparkHome, "sparkHome"); + builder.childEnv.put(ENV_SPARK_HOME, sparkHome); + return this; + } + + /** + * Set a custom properties file with Spark configuration for the application. + * + * @param path Path to custom properties file to use. + * @return This launcher. + */ + public SparkLauncher setPropertiesFile(String path) { + checkNotNull(path, "path"); + builder.propertiesFile = path; + return this; + } + + /** + * Set a single configuration value for the application. + * + * @param key Configuration key. + * @param value The value to use. + * @return This launcher. + */ + public SparkLauncher setConf(String key, String value) { + checkNotNull(key, "key"); + checkNotNull(value, "value"); + checkArgument(key.startsWith("spark."), "'key' must start with 'spark.'"); + builder.conf.put(key, value); + return this; + } + + /** + * Set the application name. + * + * @param appName Application name. + * @return This launcher. + */ + public SparkLauncher setAppName(String appName) { + checkNotNull(appName, "appName"); + builder.appName = appName; + return this; + } + + /** + * Set the Spark master for the application. + * + * @param master Spark master. + * @return This launcher. + */ + public SparkLauncher setMaster(String master) { + checkNotNull(master, "master"); + builder.master = master; + return this; + } + + /** + * Set the deploy mode for the application. + * + * @param mode Deploy mode. + * @return This launcher. + */ + public SparkLauncher setDeployMode(String mode) { + checkNotNull(mode, "mode"); + builder.deployMode = mode; + return this; + } + + /** + * Set the main application resource. This should be the location of a jar file for Scala/Java + * applications, or a python script for PySpark applications. + * + * @param resource Path to the main application resource. + * @return This launcher. + */ + public SparkLauncher setAppResource(String resource) { + checkNotNull(resource, "resource"); + builder.appResource = resource; + return this; + } + + /** + * Sets the application class name for Java/Scala applications. + * + * @param mainClass Application's main class. + * @return This launcher. + */ + public SparkLauncher setMainClass(String mainClass) { + checkNotNull(mainClass, "mainClass"); + builder.mainClass = mainClass; + return this; + } + + /** + * Adds command line arguments for the application. + * + * @param args Arguments to pass to the application's main class. + * @return This launcher. + */ + public SparkLauncher addAppArgs(String... args) { + for (String arg : args) { + checkNotNull(arg, "arg"); + builder.appArgs.add(arg); + } + return this; + } + + /** + * Adds a jar file to be submitted with the application. + * + * @param jar Path to the jar file. + * @return This launcher. + */ + public SparkLauncher addJar(String jar) { + checkNotNull(jar, "jar"); + builder.jars.add(jar); + return this; + } + + /** + * Adds a file to be submitted with the application. + * + * @param file Path to the file. + * @return This launcher. + */ + public SparkLauncher addFile(String file) { + checkNotNull(file, "file"); + builder.files.add(file); + return this; + } + + /** + * Adds a python file / zip / egg to be submitted with the application. + * + * @param file Path to the file. + * @return This launcher. + */ + public SparkLauncher addPyFile(String file) { + checkNotNull(file, "file"); + builder.pyFiles.add(file); + return this; + } + + /** + * Enables verbose reporting for SparkSubmit. + * + * @param verbose Whether to enable verbose output. + * @return This launcher. + */ + public SparkLauncher setVerbose(boolean verbose) { + builder.verbose = verbose; + return this; + } + + /** + * Launches a sub-process that will start the configured Spark application. + * + * @return A process handle for the Spark app. + */ + public Process launch() throws IOException { + List cmd = new ArrayList(); + String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; + cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); + cmd.addAll(builder.buildSparkSubmitArgs()); + + // Since the child process is a batch script, let's quote things so that special characters are + // preserved, otherwise the batch interpreter will mess up the arguments. Batch scripts are + // weird. + if (isWindows()) { + List winCmd = new ArrayList(); + for (String arg : cmd) { + winCmd.add(quoteForBatchScript(arg)); + } + cmd = winCmd; + } + + ProcessBuilder pb = new ProcessBuilder(cmd.toArray(new String[cmd.size()])); + for (Map.Entry e : builder.childEnv.entrySet()) { + pb.environment().put(e.getKey(), e.getValue()); + } + return pb.start(); + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java new file mode 100644 index 000000000000..a73c9c87e312 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +/** + * Special command builder for handling a CLI invocation of SparkSubmit. + *

    + * This builder adds command line parsing compatible with SparkSubmit. It handles setting + * driver-side options and special parsing behavior needed for the special-casing certain internal + * Spark applications. + *

    + * This class has also some special features to aid launching pyspark. + */ +class SparkSubmitCommandBuilder extends AbstractCommandBuilder { + + /** + * Name of the app resource used to identify the PySpark shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "pyspark-shell" since that identifies the PySpark shell to SparkSubmit + * (see java_gateway.py), and can cause this code to enter into an infinite loop. + */ + static final String PYSPARK_SHELL = "pyspark-shell-main"; + + /** + * This is the actual resource name that identifies the PySpark shell to SparkSubmit. + */ + static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + + /** + * This map must match the class names for available special classes, since this modifies the way + * command line parsing works. This maps the class name to the resource to use when calling + * spark-submit. + */ + private static final Map specialClasses = new HashMap(); + static { + specialClasses.put("org.apache.spark.repl.Main", "spark-shell"); + specialClasses.put("org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver", + "spark-internal"); + specialClasses.put("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2", + "spark-internal"); + } + + private final List sparkArgs; + + /** + * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed + * to parse the command lines for things like bin/spark-shell, which allows users to mix and + * match arguments (e.g. "bin/spark-shell SparkShellArg --master foo"). + */ + private boolean allowsMixedArguments; + + SparkSubmitCommandBuilder() { + this.sparkArgs = new ArrayList(); + } + + SparkSubmitCommandBuilder(List args) { + this(); + List submitArgs = args; + if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { + this.allowsMixedArguments = true; + appResource = PYSPARK_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); + } else { + this.allowsMixedArguments = false; + } + + new OptionParser().parse(submitArgs); + } + + @Override + public List buildCommand(Map env) throws IOException { + if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); + } else { + return buildSparkSubmitCommand(env); + } + } + + List buildSparkSubmitArgs() { + List args = new ArrayList(); + SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + + if (verbose) { + args.add(parser.VERBOSE); + } + + if (master != null) { + args.add(parser.MASTER); + args.add(master); + } + + if (deployMode != null) { + args.add(parser.DEPLOY_MODE); + args.add(deployMode); + } + + if (appName != null) { + args.add(parser.NAME); + args.add(appName); + } + + for (Map.Entry e : conf.entrySet()) { + args.add(parser.CONF); + args.add(String.format("%s=%s", e.getKey(), e.getValue())); + } + + if (propertiesFile != null) { + args.add(parser.PROPERTIES_FILE); + args.add(propertiesFile); + } + + if (!jars.isEmpty()) { + args.add(parser.JARS); + args.add(join(",", jars)); + } + + if (!files.isEmpty()) { + args.add(parser.FILES); + args.add(join(",", files)); + } + + if (!pyFiles.isEmpty()) { + args.add(parser.PY_FILES); + args.add(join(",", pyFiles)); + } + + if (mainClass != null) { + args.add(parser.CLASS); + args.add(mainClass); + } + + args.addAll(sparkArgs); + if (appResource != null) { + args.add(appResource); + } + args.addAll(appArgs); + + return args; + } + + private List buildSparkSubmitCommand(Map env) throws IOException { + // Load the properties file and check whether spark-submit will be running the app's driver + // or just launching a cluster app. When running the driver, the JVM's argument will be + // modified to cover the driver's configuration. + Properties props = loadPropertiesFile(); + boolean isClientMode = isClientMode(props); + String extraClassPath = isClientMode ? + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_CLASSPATH, conf, props) : null; + + List cmd = buildJavaCommand(extraClassPath); + addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); + addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); + + if (isClientMode) { + // Figuring out where the memory value come from is a little tricky due to precedence. + // Precedence is observed in the following order: + // - explicit configuration (setConf()), which also covers --driver-memory cli argument. + // - properties file. + // - SPARK_DRIVER_MEMORY env variable + // - SPARK_MEM env variable + // - default value (512m) + String memory = firstNonEmpty(firstNonEmptyValue(SparkLauncher.DRIVER_MEMORY, conf, props), + System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); + cmd.add("-Xms" + memory); + cmd.add("-Xmx" + memory); + addOptionString(cmd, firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, conf, props)); + mergeEnvPathList(env, getLibPathEnvName(), + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + } + + addPermGenSizeOpt(cmd); + cmd.add("org.apache.spark.deploy.SparkSubmit"); + cmd.addAll(buildSparkSubmitArgs()); + return cmd; + } + + private List buildPySparkShellCommand(Map env) throws IOException { + // For backwards compatibility, if a script is specified in + // the pyspark command line, then run it using spark-submit. + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".py")) { + System.err.println( + "WARNING: Running python applications through 'pyspark' is deprecated as of Spark 1.0.\n" + + "Use ./bin/spark-submit "); + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + + checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { + Properties props = loadPropertiesFile(); + mergeEnvPathList(env, getLibPathEnvName(), + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + + StringBuilder submitArgs = new StringBuilder(); + for (String arg : buildSparkSubmitArgs()) { + if (submitArgs.length() > 0) { + submitArgs.append(" "); + } + submitArgs.append(quoteForCommandString(arg)); + } + env.put(submitArgsEnvVariable, submitArgs.toString()); + } + + + private boolean isClientMode(Properties userProps) { + String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); + // Default master is "local[*]", so assume client mode in that case. + return userMaster == null || + "client".equals(deployMode) || + (!userMaster.equals("yarn-cluster") && deployMode == null); + } + + private class OptionParser extends SparkSubmitOptionParser { + + @Override + protected boolean handle(String opt, String value) { + if (opt.equals(MASTER)) { + master = value; + } else if (opt.equals(DEPLOY_MODE)) { + deployMode = value; + } else if (opt.equals(PROPERTIES_FILE)) { + propertiesFile = value; + } else if (opt.equals(DRIVER_MEMORY)) { + conf.put(SparkLauncher.DRIVER_MEMORY, value); + } else if (opt.equals(DRIVER_JAVA_OPTIONS)) { + conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, value); + } else if (opt.equals(DRIVER_LIBRARY_PATH)) { + conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, value); + } else if (opt.equals(DRIVER_CLASS_PATH)) { + conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, value); + } else if (opt.equals(CONF)) { + String[] setConf = value.split("=", 2); + checkArgument(setConf.length == 2, "Invalid argument to %s: %s", CONF, value); + conf.put(setConf[0], setConf[1]); + } else if (opt.equals(CLASS)) { + // The special classes require some special command line handling, since they allow + // mixing spark-submit arguments with arguments that should be propagated to the shell + // itself. Note that for this to work, the "--class" argument must come before any + // non-spark-submit arguments. + mainClass = value; + if (specialClasses.containsKey(value)) { + allowsMixedArguments = true; + appResource = specialClasses.get(value); + } + } else { + sparkArgs.add(opt); + if (value != null) { + sparkArgs.add(value); + } + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // When mixing arguments, add unrecognized parameters directly to the user arguments list. In + // normal mode, any unrecognized parameter triggers the end of command line parsing, and the + // parameter itself will be interpreted by SparkSubmit as the application resource. The + // remaining params will be appended to the list of SparkSubmit arguments. + if (allowsMixedArguments) { + appArgs.add(opt); + return true; + } else { + sparkArgs.add(opt); + return false; + } + } + + @Override + protected void handleExtraArgs(List extra) { + for (String arg : extra) { + sparkArgs.add(arg); + } + } + + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java new file mode 100644 index 000000000000..8526d2e7cfa3 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Parser for spark-submit command line options. + *

    + * This class encapsulates the parsing code for spark-submit command line options, so that there + * is a single list of options that needs to be maintained (well, sort of, but it makes it harder + * to break things). + */ +class SparkSubmitOptionParser { + + // The following constants define the "main" name for the available options. They're defined + // to avoid copy & paste of the raw strings where they're needed. + // + // The fields are not static so that they're exposed to Scala code that uses this class. See + // SparkSubmitArguments.scala. That is also why this class is not abstract - to allow code to + // easily use these constants without having to create dummy implementations of this class. + protected final String CLASS = "--class"; + protected final String CONF = "--conf"; + protected final String DEPLOY_MODE = "--deploy-mode"; + protected final String DRIVER_CLASS_PATH = "--driver-class-path"; + protected final String DRIVER_CORES = "--driver-cores"; + protected final String DRIVER_JAVA_OPTIONS = "--driver-java-options"; + protected final String DRIVER_LIBRARY_PATH = "--driver-library-path"; + protected final String DRIVER_MEMORY = "--driver-memory"; + protected final String EXECUTOR_MEMORY = "--executor-memory"; + protected final String FILES = "--files"; + protected final String JARS = "--jars"; + protected final String KILL_SUBMISSION = "--kill"; + protected final String MASTER = "--master"; + protected final String NAME = "--name"; + protected final String PACKAGES = "--packages"; + protected final String PROPERTIES_FILE = "--properties-file"; + protected final String PROXY_USER = "--proxy-user"; + protected final String PY_FILES = "--py-files"; + protected final String REPOSITORIES = "--repositories"; + protected final String STATUS = "--status"; + protected final String TOTAL_EXECUTOR_CORES = "--total-executor-cores"; + + // Options that do not take arguments. + protected final String HELP = "--help"; + protected final String SUPERVISE = "--supervise"; + protected final String VERBOSE = "--verbose"; + protected final String VERSION = "--version"; + + // Standalone-only options. + + // YARN-only options. + protected final String ARCHIVES = "--archives"; + protected final String EXECUTOR_CORES = "--executor-cores"; + protected final String QUEUE = "--queue"; + protected final String NUM_EXECUTORS = "--num-executors"; + + /** + * This is the canonical list of spark-submit options. Each entry in the array contains the + * different aliases for the same option; the first element of each entry is the "official" + * name of the option, passed to {@link #handle(String, String)}. + *

    + * Options not listed here nor in the "switch" list below will result in a call to + * {@link $#handleUnknown(String)}. + *

    + * These two arrays are visible for tests. + */ + final String[][] opts = { + { ARCHIVES }, + { CLASS }, + { CONF, "-c" }, + { DEPLOY_MODE }, + { DRIVER_CLASS_PATH }, + { DRIVER_CORES }, + { DRIVER_JAVA_OPTIONS }, + { DRIVER_LIBRARY_PATH }, + { DRIVER_MEMORY }, + { EXECUTOR_CORES }, + { EXECUTOR_MEMORY }, + { FILES }, + { JARS }, + { KILL_SUBMISSION }, + { MASTER }, + { NAME }, + { NUM_EXECUTORS }, + { PACKAGES }, + { PROPERTIES_FILE }, + { PROXY_USER }, + { PY_FILES }, + { QUEUE }, + { REPOSITORIES }, + { STATUS }, + { TOTAL_EXECUTOR_CORES }, + }; + + /** + * List of switches (command line options that do not take parameters) recognized by spark-submit. + */ + final String[][] switches = { + { HELP, "-h" }, + { SUPERVISE }, + { VERBOSE, "-v" }, + { VERSION }, + }; + + /** + * Parse a list of spark-submit command line options. + *

    + * See SparkSubmitArguments.scala for a more formal description of available options. + * + * @throws IllegalArgumentException If an error is found during parsing. + */ + protected final void parse(List args) { + Pattern eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)"); + + int idx = 0; + for (idx = 0; idx < args.size(); idx++) { + String arg = args.get(idx); + String value = null; + + Matcher m = eqSeparatedOpt.matcher(arg); + if (m.matches()) { + arg = m.group(1); + value = m.group(2); + } + + // Look for options with a value. + String name = findCliOption(arg, opts); + if (name != null) { + if (value == null) { + if (idx == args.size() - 1) { + throw new IllegalArgumentException( + String.format("Missing argument for option '%s'.", arg)); + } + idx++; + value = args.get(idx); + } + if (!handle(name, value)) { + break; + } + continue; + } + + // Look for a switch. + name = findCliOption(arg, switches); + if (name != null) { + if (!handle(name, null)) { + break; + } + continue; + } + + if (!handleUnknown(arg)) { + break; + } + } + + if (idx < args.size()) { + idx++; + } + handleExtraArgs(args.subList(idx, args.size())); + } + + /** + * Callback for when an option with an argument is parsed. + * + * @param opt The long name of the cli option (might differ from actual command line). + * @param value The value. This will be null if the option does not take a value. + * @return Whether to continue parsing the argument list. + */ + protected boolean handle(String opt, String value) { + throw new UnsupportedOperationException(); + } + + /** + * Callback for when an unrecognized option is parsed. + * + * @param opt Unrecognized option from the command line. + * @return Whether to continue parsing the argument list. + */ + protected boolean handleUnknown(String opt) { + throw new UnsupportedOperationException(); + } + + /** + * Callback for remaining command line arguments after either {@link #handle(String, String)} or + * {@link #handleUnknown(String)} return "false". This will be called at the end of parsing even + * when there are no remaining arguments. + * + * @param extra List of remaining arguments. + */ + protected void handleExtraArgs(List extra) { + throw new UnsupportedOperationException(); + } + + private String findCliOption(String name, String[][] available) { + for (String[] candidates : available) { + for (String candidate : candidates) { + if (candidate.equals(name)) { + return candidates[0]; + } + } + } + return null; + } + +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java new file mode 100644 index 000000000000..7ed756f4b859 --- /dev/null +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Library for launching Spark applications. + *

    + * This library allows applications to launch Spark programmatically. There's only one entry + * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. + *

    + * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} + * and configure the application to run. For example: + * + *

    + * {@code
    + *   import org.apache.spark.launcher.SparkLauncher;
    + *
    + *   public class MyLauncher {
    + *     public static void main(String[] args) throws Exception {
    + *       Process spark = new SparkLauncher()
    + *         .setAppResource("/my/app.jar")
    + *         .setMainClass("my.spark.app.Main")
    + *         .setMaster("local")
    + *         .setConf(SparkLauncher.DRIVER_MEMORY, "2g")
    + *         .launch();
    + *       spark.waitFor();
    + *     }
    + *   }
    + * }
    + * 
    + */ +package org.apache.spark.launcher; diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java new file mode 100644 index 000000000000..1ae42eed8a3a --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; +import static org.junit.Assert.*; + +import static org.apache.spark.launcher.CommandBuilderUtils.*; + +public class CommandBuilderUtilsSuite { + + @Test + public void testValidOptionStrings() { + testOpt("a b c d e", Arrays.asList("a", "b", "c", "d", "e")); + testOpt("a 'b c' \"d\" e", Arrays.asList("a", "b c", "d", "e")); + testOpt("a 'b\\\"c' \"'d'\" e", Arrays.asList("a", "b\\\"c", "'d'", "e")); + testOpt("a 'b\"c' \"\\\"d\\\"\" e", Arrays.asList("a", "b\"c", "\"d\"", "e")); + testOpt(" a b c \\\\ ", Arrays.asList("a", "b", "c", "\\")); + + // Following tests ported from UtilsSuite.scala. + testOpt("", new ArrayList()); + testOpt("a", Arrays.asList("a")); + testOpt("aaa", Arrays.asList("aaa")); + testOpt("a b c", Arrays.asList("a", "b", "c")); + testOpt(" a b\t c ", Arrays.asList("a", "b", "c")); + testOpt("a 'b c'", Arrays.asList("a", "b c")); + testOpt("a 'b c' d", Arrays.asList("a", "b c", "d")); + testOpt("'b c'", Arrays.asList("b c")); + testOpt("a \"b c\"", Arrays.asList("a", "b c")); + testOpt("a \"b c\" d", Arrays.asList("a", "b c", "d")); + testOpt("\"b c\"", Arrays.asList("b c")); + testOpt("a 'b\" c' \"d' e\"", Arrays.asList("a", "b\" c", "d' e")); + testOpt("a\t'b\nc'\nd", Arrays.asList("a", "b\nc", "d")); + testOpt("a \"b\\\\c\"", Arrays.asList("a", "b\\c")); + testOpt("a \"b\\\"c\"", Arrays.asList("a", "b\"c")); + testOpt("a 'b\\\"c'", Arrays.asList("a", "b\\\"c")); + testOpt("'a'b", Arrays.asList("ab")); + testOpt("'a''b'", Arrays.asList("ab")); + testOpt("\"a\"b", Arrays.asList("ab")); + testOpt("\"a\"\"b\"", Arrays.asList("ab")); + testOpt("''", Arrays.asList("")); + testOpt("\"\"", Arrays.asList("")); + } + + @Test + public void testInvalidOptionStrings() { + testInvalidOpt("\\"); + testInvalidOpt("\"abcde"); + testInvalidOpt("'abcde"); + } + + @Test + public void testWindowsBatchQuoting() { + assertEquals("abc", quoteForBatchScript("abc")); + assertEquals("\"a b c\"", quoteForBatchScript("a b c")); + assertEquals("\"a \"\"b\"\" c\"", quoteForBatchScript("a \"b\" c")); + assertEquals("\"a\"\"b\"\"c\"", quoteForBatchScript("a\"b\"c")); + assertEquals("\"ab^=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + } + + @Test + public void testPythonArgQuoting() { + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); + } + + private void testOpt(String opts, List expected) { + assertEquals(String.format("test string failed to parse: [[ %s ]]", opts), + expected, parseOptionString(opts)); + } + + private void testInvalidOpt(String opts) { + try { + parseOptionString(opts); + fail("Expected exception for invalid option string."); + } catch (IllegalArgumentException e) { + // pass. + } + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java new file mode 100644 index 000000000000..252d5abae1ca --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.junit.Assert.*; + +/** + * These tests require the Spark assembly to be built before they can be run. + */ +public class SparkLauncherSuite { + + private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + + @Test + public void testChildProcLauncher() throws Exception { + Map env = new HashMap(); + env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); + + SparkLauncher launcher = new SparkLauncher(env) + .setSparkHome(System.getProperty("spark.test.home")) + .setMaster("local") + .setAppResource("spark-internal") + .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, + "-Dfoo=bar -Dtest.name=-testChildProcLauncher") + .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .setMainClass(SparkLauncherTestApp.class.getName()) + .addAppArgs("proc"); + final Process app = launcher.launch(); + new Redirector("stdout", app.getInputStream()).start(); + new Redirector("stderr", app.getErrorStream()).start(); + assertEquals(0, app.waitFor()); + } + + public static class SparkLauncherTestApp { + + public static void main(String[] args) throws Exception { + assertEquals(1, args.length); + assertEquals("proc", args[0]); + assertEquals("bar", System.getProperty("foo")); + assertEquals("local", System.getProperty(SparkLauncher.SPARK_MASTER)); + } + + } + + private static class Redirector extends Thread { + + private final InputStream in; + + Redirector(String name, InputStream in) { + this.in = in; + setName(name); + setDaemon(true); + } + + @Override + public void run() { + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + String line; + while ((line = reader.readLine()) != null) { + LOG.warn(line); + } + } catch (Exception e) { + LOG.error("Error reading process output.", e); + } + } + + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java new file mode 100644 index 000000000000..97043a76cc61 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.io.File; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +public class SparkSubmitCommandBuilderSuite { + + private static File dummyPropsFile; + private static SparkSubmitOptionParser parser; + + @BeforeClass + public static void setUp() throws Exception { + dummyPropsFile = File.createTempFile("spark", "properties"); + parser = new SparkSubmitOptionParser(); + } + + @AfterClass + public static void cleanUp() throws Exception { + dummyPropsFile.delete(); + } + + @Test + public void testDriverCmdBuilder() throws Exception { + testCmdBuilder(true); + } + + @Test + public void testClusterCmdBuilder() throws Exception { + testCmdBuilder(false); + } + + @Test + public void testCliParser() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.MASTER, + "local", + parser.DRIVER_MEMORY, + "42g", + parser.DRIVER_CLASS_PATH, + "/driverCp", + parser.DRIVER_JAVA_OPTIONS, + "extraJavaOpt", + parser.CONF, + "spark.randomOption=foo", + parser.CONF, + SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH + "=/driverLibPath"); + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + + assertTrue(findInStringList(env.get(CommandBuilderUtils.getLibPathEnvName()), + File.pathSeparator, "/driverLibPath")); + assertTrue(findInStringList(findArgValue(cmd, "-cp"), File.pathSeparator, "/driverCp")); + assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms42g")); + assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx42g")); + assertTrue("Command should contain user-defined conf.", + Collections.indexOfSubList(cmd, Arrays.asList(parser.CONF, "spark.randomOption=foo")) > 0); + } + + @Test + public void testShellCliParser() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.CLASS, + "org.apache.spark.repl.Main", + parser.MASTER, + "foo", + "--app-arg", + "bar", + "--app-switch", + parser.FILES, + "baz", + parser.NAME, + "appName"); + + List args = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List expected = Arrays.asList("spark-shell", "--app-arg", "bar", "--app-switch"); + assertEquals(expected, args.subList(args.size() - expected.size(), args.size())); + } + + @Test + public void testAlternateSyntaxParsing() throws Exception { + List sparkSubmitArgs = Arrays.asList( + parser.CLASS + "=org.my.Class", + parser.MASTER + "=foo", + parser.DEPLOY_MODE + "=bar"); + + List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); + assertEquals("foo", findArgValue(cmd, parser.MASTER)); + assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); + } + + @Test + public void testPySparkLauncher() throws Exception { + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.PYSPARK_SHELL, + "--master=foo", + "--deploy-mode=bar"); + + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + assertEquals("python", cmd.get(cmd.size() - 1)); + assertEquals( + String.format("\"%s\" \"foo\" \"%s\" \"bar\" \"%s\"", + parser.MASTER, parser.DEPLOY_MODE, SparkSubmitCommandBuilder.PYSPARK_SHELL_RESOURCE), + env.get("PYSPARK_SUBMIT_ARGS")); + } + + @Test + public void testPySparkFallback() throws Exception { + List sparkSubmitArgs = Arrays.asList( + "--master=foo", + "--deploy-mode=bar", + "script.py", + "arg1"); + + Map env = new HashMap(); + List cmd = buildCommand(sparkSubmitArgs, env); + + assertEquals("foo", findArgValue(cmd, "--master")); + assertEquals("bar", findArgValue(cmd, "--deploy-mode")); + assertEquals("script.py", cmd.get(cmd.size() - 2)); + assertEquals("arg1", cmd.get(cmd.size() - 1)); + } + + private void testCmdBuilder(boolean isDriver) throws Exception { + String deployMode = isDriver ? "client" : "cluster"; + + SparkSubmitCommandBuilder launcher = + newCommandBuilder(Collections.emptyList()); + launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, + System.getProperty("spark.test.home")); + launcher.master = "yarn"; + launcher.deployMode = deployMode; + launcher.appResource = "/foo"; + launcher.appName = "MyApp"; + launcher.mainClass = "my.Class"; + launcher.propertiesFile = dummyPropsFile.getAbsolutePath(); + launcher.appArgs.add("foo"); + launcher.appArgs.add("bar"); + launcher.conf.put(SparkLauncher.DRIVER_MEMORY, "1g"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_CLASSPATH, "/driver"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Ddriver -XX:MaxPermSize=256m"); + launcher.conf.put(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, "/native"); + launcher.conf.put("spark.foo", "foo"); + + Map env = new HashMap(); + List cmd = launcher.buildCommand(env); + + // Checks below are different for driver and non-driver mode. + + if (isDriver) { + assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms1g")); + assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g")); + } else { + boolean found = false; + for (String arg : cmd) { + if (arg.startsWith("-Xms") || arg.startsWith("-Xmx")) { + found = true; + break; + } + } + assertFalse("Memory arguments should not be set.", found); + } + + for (String arg : cmd) { + if (arg.startsWith("-XX:MaxPermSize=")) { + if (isDriver) { + assertEquals("-XX:MaxPermSize=256m", arg); + } else { + assertEquals("-XX:MaxPermSize=128m", arg); + } + } + } + + String[] cp = findArgValue(cmd, "-cp").split(Pattern.quote(File.pathSeparator)); + if (isDriver) { + assertTrue("Driver classpath should contain provided entry.", contains("/driver", cp)); + } else { + assertFalse("Driver classpath should not be in command.", contains("/driver", cp)); + } + + String libPath = env.get(CommandBuilderUtils.getLibPathEnvName()); + if (isDriver) { + assertNotNull("Native library path should be set.", libPath); + assertTrue("Native library path should contain provided entry.", + contains("/native", libPath.split(Pattern.quote(File.pathSeparator)))); + } else { + assertNull("Native library should not be set.", libPath); + } + + // Checks below are the same for both driver and non-driver mode. + assertEquals(dummyPropsFile.getAbsolutePath(), findArgValue(cmd, parser.PROPERTIES_FILE)); + assertEquals("yarn", findArgValue(cmd, parser.MASTER)); + assertEquals(deployMode, findArgValue(cmd, parser.DEPLOY_MODE)); + assertEquals("my.Class", findArgValue(cmd, parser.CLASS)); + assertEquals("MyApp", findArgValue(cmd, parser.NAME)); + + boolean appArgsOk = false; + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals("/foo")) { + assertEquals("foo", cmd.get(i + 1)); + assertEquals("bar", cmd.get(i + 2)); + assertEquals(cmd.size(), i + 3); + appArgsOk = true; + break; + } + } + assertTrue("App resource and args should be added to command.", appArgsOk); + + Map conf = parseConf(cmd, parser); + assertEquals("foo", conf.get("spark.foo")); + } + + private boolean contains(String needle, String[] haystack) { + for (String entry : haystack) { + if (entry.equals(needle)) { + return true; + } + } + return false; + } + + private Map parseConf(List cmd, SparkSubmitOptionParser parser) { + Map conf = new HashMap(); + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals(parser.CONF)) { + String[] val = cmd.get(i + 1).split("=", 2); + conf.put(val[0], val[1]); + i += 1; + } + } + return conf; + } + + private String findArgValue(List cmd, String name) { + for (int i = 0; i < cmd.size(); i++) { + if (cmd.get(i).equals(name)) { + return cmd.get(i + 1); + } + } + fail(String.format("arg '%s' not found", name)); + return null; + } + + private boolean findInStringList(String list, String sep, String needle) { + return contains(needle, list.split(sep)); + } + + private SparkSubmitCommandBuilder newCommandBuilder(List args) { + SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); + return builder; + } + + private List buildCommand(List args, Map env) throws Exception { + return newCommandBuilder(args).buildCommand(env); + } + +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java new file mode 100644 index 000000000000..f3d210991705 --- /dev/null +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitOptionParserSuite.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import static org.apache.spark.launcher.SparkSubmitOptionParser.*; + +public class SparkSubmitOptionParserSuite { + + private SparkSubmitOptionParser parser; + + @Before + public void setUp() { + parser = spy(new DummyParser()); + } + + @Test + public void testAllOptions() { + int count = 0; + for (String[] optNames : parser.opts) { + for (String optName : optNames) { + String value = optName + "-value"; + parser.parse(Arrays.asList(optName, value)); + count++; + verify(parser).handle(eq(optNames[0]), eq(value)); + verify(parser, times(count)).handle(anyString(), anyString()); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + } + } + + for (String[] switchNames : parser.switches) { + int switchCount = 0; + for (String name : switchNames) { + parser.parse(Arrays.asList(name)); + count++; + switchCount++; + verify(parser, times(switchCount)).handle(eq(switchNames[0]), same((String) null)); + verify(parser, times(count)).handle(anyString(), any(String.class)); + verify(parser, times(count)).handleExtraArgs(eq(Collections.emptyList())); + } + } + } + + @Test + public void testExtraOptions() { + List args = Arrays.asList(parser.MASTER, parser.MASTER, "foo", "bar"); + parser.parse(args); + verify(parser).handle(eq(parser.MASTER), eq(parser.MASTER)); + verify(parser).handleUnknown(eq("foo")); + verify(parser).handleExtraArgs(eq(Arrays.asList("bar"))); + } + + @Test(expected=IllegalArgumentException.class) + public void testMissingArg() { + parser.parse(Arrays.asList(parser.MASTER)); + } + + @Test + public void testEqualSeparatedOption() { + List args = Arrays.asList(parser.MASTER + "=" + parser.MASTER); + parser.parse(args); + verify(parser).handle(eq(parser.MASTER), eq(parser.MASTER)); + verify(parser).handleExtraArgs(eq(Collections.emptyList())); + } + + private static class DummyParser extends SparkSubmitOptionParser { + + @Override + protected boolean handle(String opt, String value) { + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + + } + + } + +} diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties new file mode 100644 index 000000000000..67a6a9821711 --- /dev/null +++ b/launcher/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false + +# Some tests will set "test.name" to avoid overwriting the main log file. +log4j.appender.file.file=target/unit-tests${test.name}.log + +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be305..cb65932b4abc 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,6 +32,10 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false +TACHYON_VERSION="0.6.4" +TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" +TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" + MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -93,10 +97,10 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found - if which rpm &>/dev/null; then - RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + if [ $(command -v rpm) ]; then + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME=$RPM_JAVA_HOME + JAVA_HOME="$RPM_JAVA_HOME" echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi @@ -107,25 +111,29 @@ if [ -z "$JAVA_HOME" ]; then exit -1 fi -if which git &>/dev/null; then +if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z $GITREV ]; then + if [ ! -z "$GITREV" ]; then GITREVSTRING=" (git revision $GITREV)" fi unset GITREV fi -if ! which $MVN &>/dev/null; then + +if [ ! $(command -v "$MVN") ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | tail -n 1) +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -142,7 +150,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then echo "Output from 'java -version' was:" echo "$JAVA_VERSION" read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! $REPLY =~ ^[Yy]$ ]]; then + if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 fi @@ -171,13 +179,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." -echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" @@ -222,16 +233,22 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.5.0" - TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - pushd $TMPD > /dev/null + pushd "$TMPD" > /dev/null echo "Fetching tachyon tgz" - wget "$TACHYON_URL" - tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" + TACHYON_DL="${TACHYON_TGZ}.part" + if [ $(command -v curl) ]; then + curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + elif [ $(command -v wget) ]; then + wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + else + printf "You do not have curl or wget installed. please install Tachyon manually.\n" + exit -1 + fi + + tar xzf "${TACHYON_TGZ}" cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" @@ -245,7 +262,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi popd > /dev/null - rm -rf $TMPD + rm -rf "$TMPD" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index a0bda89ccaa7..5dfab36c7690 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -50,15 +50,21 @@ spark-sql_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-graphx_${scala.binary.version} + ${project.version} + org.jblas jblas ${jblas.version} + test org.scalanlp breeze_${scala.binary.version} - 0.10 + 0.11.2 @@ -111,7 +117,7 @@ com.github.fommil.netlib all - 1.1.2 + ${netlib.java.version} pom @@ -125,6 +131,9 @@ ../python pyspark/mllib/*.py + pyspark/mllib/stat/*.py + pyspark/ml/*.py + pyspark/ml/param/*.py diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a12..d6b3503ebdd9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -34,12 +34,13 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) + * @param paramPairs Optional list of param pairs. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { + val map = ParamMap(paramPairs: _*) fit(dataset, map) } @@ -47,10 +48,11 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with provided parameter map. * * @param dataset input dataset - * @param paramMap parameter map + * @param paramMap Parameter map. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -58,10 +60,11 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Subclasses could overwrite this to optimize multi-model training. * * @param dataset input dataset - * @param paramMaps an array of parameter maps + * @param paramMaps An array of parameter maps. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e5..d2ca2e6871e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index cd84b05bfb49..a1d49095c24a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,9 +25,9 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. */ private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) + this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae..8eddf79cdfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -20,9 +20,9 @@ package org.apache.spark.ml import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -33,12 +33,23 @@ import org.apache.spark.sql.types.StructType abstract class PipelineStage extends Serializable with Logging { /** + * :: DeveloperApi :: + * * Derives the output schema from the input schema and parameters. + * The schema describes the columns and types of the data. + * + * @param schema Input schema to this stage + * @param paramMap Parameters passed to this stage + * @return Output schema from this stage */ - private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + @DeveloperApi + def transformSchema(schema: StructType, paramMap: ParamMap): StructType /** * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ protected def transformSchema( schema: StructType, @@ -58,11 +69,11 @@ abstract class PipelineStage extends Serializable with Logging { /** * :: AlphaComponent :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each - * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the - * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will * be called on the input dataset to fit a model. Then the model, which is a transformer, will be * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], - * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * its [[Transformer#transform]] method will be called to produce the dataset for the next stage. * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. @@ -73,13 +84,13 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an - * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. * Then the model, which is a transformer, will be used to transform the dataset as the input to - * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * the next stage. If a stage is a [[Transformer]], its [[Transformer#transform]] method will be * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the * pipeline stages. If there are no stages, the output model acts as an identity transformer. @@ -88,9 +99,9 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -114,7 +125,9 @@ class Pipeline extends Estimator[PipelineModel] { throw new IllegalArgumentException( s"Do not support stage $stage of type ${stage.getClass}") } - curDataset = transformer.transform(curDataset, paramMap) + if (index < indexOfLastEstimator) { + curDataset = transformer.transform(curDataset, paramMap) + } transformers += transformer } else { transformers += stage.asInstanceOf[Transformer] @@ -124,8 +137,8 @@ class Pipeline extends Estimator[PipelineModel] { new PipelineModel(this, map, transformers.toArray) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -162,16 +175,16 @@ class PipelineModel private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c43535..0acda71ec604 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,9 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.ml.param.shared._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -63,7 +63,10 @@ abstract class Transformer extends PipelineStage with Params { private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { + /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** @@ -84,22 +87,21 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") } val outputFields = schema.fields :+ - StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive) + StructField(map(outputCol), outputDataType, nullable = false) StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + val map = extractParamMap(paramMap) + dataset.withColumn(map(outputCol), + callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala new file mode 100644 index 000000000000..d7dee8fed2a5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} + +/** + * Attributes that describe a vector ML column. + * + * @param name name of the attribute group (the ML column name) + * @param numAttributes optional number of attributes. At most one of `numAttributes` and `attrs` + * can be defined. + * @param attrs optional array of attributes. Attribute will be copied with their corresponding + * indices in the array. + */ +class AttributeGroup private ( + val name: String, + val numAttributes: Option[Int], + attrs: Option[Array[Attribute]]) extends Serializable { + + require(name.nonEmpty, "Cannot have an empty string for name.") + require(!(numAttributes.isDefined && attrs.isDefined), + "Cannot have both numAttributes and attrs defined.") + + /** + * Creates an attribute group without attribute info. + * @param name name of the attribute group + */ + def this(name: String) = this(name, None, None) + + /** + * Creates an attribute group knowing only the number of attributes. + * @param name name of the attribute group + * @param numAttributes number of attributes + */ + def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None) + + /** + * Creates an attribute group with attributes. + * @param name name of the attribute group + * @param attrs array of attributes. Attributes will be copied with their corresponding indices in + * the array. + */ + def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs)) + + /** + * Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined. + */ + val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) => + attr.withIndex(i) + }.toArray) + + private lazy val nameToIndex: Map[String, Int] = { + attributes.map(_.view.flatMap { attr => + attr.name.map(_ -> attr.index.get) + }.toMap).getOrElse(Map.empty) + } + + /** Size of the attribute group. Returns -1 if the size is unknown. */ + def size: Int = { + if (numAttributes.isDefined) { + numAttributes.get + } else if (attributes.isDefined) { + attributes.get.length + } else { + -1 + } + } + + /** Test whether this attribute group contains a specific attribute. */ + def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName) + + /** Index of an attribute specified by name. */ + def indexOf(attrName: String): Int = nameToIndex(attrName) + + /** Gets an attribute by its name. */ + def apply(attrName: String): Attribute = { + attributes.get(indexOf(attrName)) + } + + /** Gets an attribute by its name. */ + def getAttr(attrName: String): Attribute = this(attrName) + + /** Gets an attribute by its index. */ + def apply(attrIndex: Int): Attribute = attributes.get(attrIndex) + + /** Gets an attribute by its index. */ + def getAttr(attrIndex: Int): Attribute = this(attrIndex) + + /** Converts to metadata without name. */ + private[attribute] def toMetadataImpl: Metadata = { + import AttributeKeys._ + val bldr = new MetadataBuilder() + if (attributes.isDefined) { + val numericMetadata = ArrayBuffer.empty[Metadata] + val nominalMetadata = ArrayBuffer.empty[Metadata] + val binaryMetadata = ArrayBuffer.empty[Metadata] + attributes.get.foreach { + case numeric: NumericAttribute => + // Skip default numeric attributes. + if (numeric.withoutIndex != NumericAttribute.defaultAttr) { + numericMetadata += numeric.toMetadataImpl(withType = false) + } + case nominal: NominalAttribute => + nominalMetadata += nominal.toMetadataImpl(withType = false) + case binary: BinaryAttribute => + binaryMetadata += binary.toMetadataImpl(withType = false) + } + val attrBldr = new MetadataBuilder + if (numericMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Numeric.name, numericMetadata.toArray) + } + if (nominalMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Nominal.name, nominalMetadata.toArray) + } + if (binaryMetadata.nonEmpty) { + attrBldr.putMetadataArray(AttributeType.Binary.name, binaryMetadata.toArray) + } + bldr.putMetadata(ATTRIBUTES, attrBldr.build()) + bldr.putLong(NUM_ATTRIBUTES, attributes.get.length) + } else if (numAttributes.isDefined) { + bldr.putLong(NUM_ATTRIBUTES, numAttributes.get) + } + bldr.build() + } + + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) + } + + /** Converts to a StructField. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def equals(other: Any): Boolean = { + other match { + case o: AttributeGroup => + (name == o.name) && + (numAttributes == o.numAttributes) && + (attributes.map(_.toSeq) == o.attributes.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + numAttributes.hashCode + sum = 37 * sum + attributes.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods to create attribute groups. */ +object AttributeGroup { + + import AttributeKeys._ + + /** Creates an attribute group from a [[Metadata]] instance with name. */ + private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = { + import org.apache.spark.ml.attribute.AttributeType._ + if (metadata.contains(ATTRIBUTES)) { + val numAttrs = metadata.getLong(NUM_ATTRIBUTES).toInt + val attributes = new Array[Attribute](numAttrs) + val attrMetadata = metadata.getMetadata(ATTRIBUTES) + if (attrMetadata.contains(Numeric.name)) { + attrMetadata.getMetadataArray(Numeric.name) + .map(NumericAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Nominal.name)) { + attrMetadata.getMetadataArray(Nominal.name) + .map(NominalAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + if (attrMetadata.contains(Binary.name)) { + attrMetadata.getMetadataArray(Binary.name) + .map(BinaryAttribute.fromMetadata) + .foreach { attr => + attributes(attr.index.get) = attr + } + } + var i = 0 + while (i < numAttrs) { + if (attributes(i) == null) { + attributes(i) = NumericAttribute.defaultAttr + } + i += 1 + } + new AttributeGroup(name, attributes) + } else if (metadata.contains(NUM_ATTRIBUTES)) { + new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES).toInt) + } else { + new AttributeGroup(name) + } + } + + /** Creates an attribute group from a [[StructField]] instance. */ + def fromStructField(field: StructField): AttributeGroup = { + require(field.dataType == new VectorUDT) + if (field.metadata.contains(ML_ATTR)) { + fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name) + } else { + new AttributeGroup(field.name) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala new file mode 100644 index 000000000000..f714f7becc7e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeKeys.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +/** + * Keys used to store attributes. + */ +private[attribute] object AttributeKeys { + val ML_ATTR: String = "ml_attr" + val TYPE: String = "type" + val NAME: String = "name" + val INDEX: String = "idx" + val MIN: String = "min" + val MAX: String = "max" + val STD: String = "std" + val SPARSITY: String = "sparsity" + val ORDINAL: String = "ord" + val VALUES: String = "vals" + val NUM_VALUES: String = "num_vals" + val ATTRIBUTES: String = "attrs" + val NUM_ATTRIBUTES: String = "num_attrs" +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala new file mode 100644 index 000000000000..65e7e43d5a5b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +/** + * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], + * and [[AttributeType$#Binary]]. + */ +sealed abstract class AttributeType(val name: String) + +object AttributeType { + + /** Numeric type. */ + val Numeric: AttributeType = { + case object Numeric extends AttributeType("numeric") + Numeric + } + + /** Nominal type. */ + val Nominal: AttributeType = { + case object Nominal extends AttributeType("nominal") + Nominal + } + + /** Binary type. */ + val Binary: AttributeType = { + case object Binary extends AttributeType("binary") + Binary + } + + /** + * Gets the [[AttributeType]] object from its name. + * @param name attribute type name: "numeric", "nominal", or "binary" + */ + def fromName(name: String): AttributeType = { + if (name == Numeric.name) { + Numeric + } else if (name == Nominal.name) { + Nominal + } else if (name == Binary.name) { + Binary + } else { + throw new IllegalArgumentException(s"Cannot recognize type $name.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala new file mode 100644 index 000000000000..5717d6ec2eae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -0,0 +1,537 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import scala.annotation.varargs + +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} + +/** + * Abstract class for ML attributes. + */ +sealed abstract class Attribute extends Serializable { + + name.foreach { n => + require(n.nonEmpty, "Cannot have an empty string for name.") + } + index.foreach { i => + require(i >= 0, s"Index cannot be negative but got $i") + } + + /** Attribute type. */ + def attrType: AttributeType + + /** Name of the attribute. None if it is not set. */ + def name: Option[String] + + /** Copy with a new name. */ + def withName(name: String): Attribute + + /** Copy without the name. */ + def withoutName: Attribute + + /** Index of the attribute. None if it is not set. */ + def index: Option[Int] + + /** Copy with a new index. */ + def withIndex(index: Int): Attribute + + /** Copy without the index. */ + def withoutIndex: Attribute + + /** + * Tests whether this attribute is numeric, true for [[NumericAttribute]] and [[BinaryAttribute]]. + */ + def isNumeric: Boolean + + /** + * Tests whether this attribute is nominal, true for [[NominalAttribute]] and [[BinaryAttribute]]. + */ + def isNominal: Boolean + + /** + * Converts this attribute to [[Metadata]]. + * @param withType whether to include the type info + */ + private[attribute] def toMetadataImpl(withType: Boolean): Metadata + + /** + * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to + * save space, because numeric type is the default attribute type. For nominal and binary + * attributes, the type info is included. + */ + private[attribute] def toMetadataImpl(): Metadata = { + if (attrType == AttributeType.Numeric) { + toMetadataImpl(withType = false) + } else { + toMetadataImpl(withType = true) + } + } + + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl()) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + + /** + * Converts to a [[StructField]] with some existing metadata. + * @param existingMetadata existing metadata to carry over + */ + def toStructField(existingMetadata: Metadata): StructField = { + val newMetadata = new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl()) + .build() + StructField(name.get, DoubleType, nullable = false, newMetadata) + } + + /** Converts to a [[StructField]]. */ + def toStructField(): StructField = toStructField(Metadata.empty) + + override def toString: String = toMetadataImpl(withType = true).toString +} + +/** Trait for ML attribute factories. */ +private[attribute] trait AttributeFactory { + + /** + * Creates an [[Attribute]] from a [[Metadata]] instance. + */ + private[attribute] def fromMetadata(metadata: Metadata): Attribute + + /** + * Creates an [[Attribute]] from a [[StructField]] instance. + */ + def fromStructField(field: StructField): Attribute = { + require(field.dataType == DoubleType) + fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name) + } +} + +object Attribute extends AttributeFactory { + + private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val attrType = if (metadata.contains(TYPE)) { + metadata.getString(TYPE) + } else { + AttributeType.Numeric.name + } + getFactory(attrType).fromMetadata(metadata) + } + + /** Gets the attribute factory given the attribute type name. */ + private def getFactory(attrType: String): AttributeFactory = { + if (attrType == AttributeType.Numeric.name) { + NumericAttribute + } else if (attrType == AttributeType.Nominal.name) { + NominalAttribute + } else if (attrType == AttributeType.Binary.name) { + BinaryAttribute + } else { + throw new IllegalArgumentException(s"Cannot recognize type $attrType.") + } + } +} + + +/** + * A numeric attribute with optional summary statistics. + * @param name optional name + * @param index optional index + * @param min optional min value + * @param max optional max value + * @param std optional standard deviation + * @param sparsity optional sparsity (ratio of zeros) + */ +class NumericAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val min: Option[Double] = None, + val max: Option[Double] = None, + val std: Option[Double] = None, + val sparsity: Option[Double] = None) extends Attribute { + + std.foreach { s => + require(s >= 0.0, s"Standard deviation cannot be negative but got $s.") + } + sparsity.foreach { s => + require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.") + } + + override def attrType: AttributeType = AttributeType.Numeric + + override def withName(name: String): NumericAttribute = copy(name = Some(name)) + override def withoutName: NumericAttribute = copy(name = None) + + override def withIndex(index: Int): NumericAttribute = copy(index = Some(index)) + override def withoutIndex: NumericAttribute = copy(index = None) + + /** Copy with a new min value. */ + def withMin(min: Double): NumericAttribute = copy(min = Some(min)) + + /** Copy without the min value. */ + def withoutMin: NumericAttribute = copy(min = None) + + + /** Copy with a new max value. */ + def withMax(max: Double): NumericAttribute = copy(max = Some(max)) + + /** Copy without the max value. */ + def withoutMax: NumericAttribute = copy(max = None) + + /** Copy with a new standard deviation. */ + def withStd(std: Double): NumericAttribute = copy(std = Some(std)) + + /** Copy without the standard deviation. */ + def withoutStd: NumericAttribute = copy(std = None) + + /** Copy with a new sparsity. */ + def withSparsity(sparsity: Double): NumericAttribute = copy(sparsity = Some(sparsity)) + + /** Copy without the sparsity. */ + def withoutSparsity: NumericAttribute = copy(sparsity = None) + + /** Copy without summary statistics. */ + def withoutSummary: NumericAttribute = copy(min = None, max = None, std = None, sparsity = None) + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = false + + /** Convert this attribute to metadata. */ + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + min.foreach(bldr.putDouble(MIN, _)) + max.foreach(bldr.putDouble(MAX, _)) + std.foreach(bldr.putDouble(STD, _)) + sparsity.foreach(bldr.putDouble(SPARSITY, _)) + bldr.build() + } + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + min: Option[Double] = min, + max: Option[Double] = max, + std: Option[Double] = std, + sparsity: Option[Double] = sparsity): NumericAttribute = { + new NumericAttribute(name, index, min, max, std, sparsity) + } + + override def equals(other: Any): Boolean = { + other match { + case o: NumericAttribute => + (name == o.name) && + (index == o.index) && + (min == o.min) && + (max == o.max) && + (std == o.std) && + (sparsity == o.sparsity) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + min.hashCode + sum = 37 * sum + max.hashCode + sum = 37 * sum + std.hashCode + sum = 37 * sum + sparsity.hashCode + sum + } +} + +/** + * Factory methods for numeric attributes. + */ +object NumericAttribute extends AttributeFactory { + + /** The default numeric attribute. */ + val defaultAttr: NumericAttribute = new NumericAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NumericAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val min = if (metadata.contains(MIN)) Some(metadata.getDouble(MIN)) else None + val max = if (metadata.contains(MAX)) Some(metadata.getDouble(MAX)) else None + val std = if (metadata.contains(STD)) Some(metadata.getDouble(STD)) else None + val sparsity = if (metadata.contains(SPARSITY)) Some(metadata.getDouble(SPARSITY)) else None + new NumericAttribute(name, index, min, max, std, sparsity) + } +} + +/** + * A nominal attribute. + * @param name optional name + * @param index optional index + * @param isOrdinal whether this attribute is ordinal (optional) + * @param numValues optional number of values. At most one of `numValues` and `values` can be + * defined. + * @param values optional values. At most one of `numValues` and `values` can be defined. + */ +class NominalAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val isOrdinal: Option[Boolean] = None, + val numValues: Option[Int] = None, + val values: Option[Array[String]] = None) extends Attribute { + + numValues.foreach { n => + require(n >= 0, s"numValues cannot be negative but got $n.") + } + require(!(numValues.isDefined && values.isDefined), + "Cannot have both numValues and values defined.") + + override def attrType: AttributeType = AttributeType.Nominal + + override def isNumeric: Boolean = false + + override def isNominal: Boolean = true + + private lazy val valueToIndex: Map[String, Int] = { + values.map(_.zipWithIndex.toMap).getOrElse(Map.empty) + } + + /** Index of a specific value. */ + def indexOf(value: String): Int = { + valueToIndex(value) + } + + /** Tests whether this attribute contains a specific value. */ + def hasValue(value: String): Boolean = valueToIndex.contains(value) + + /** Gets a value given its index. */ + def getValue(index: Int): String = values.get(index) + + override def withName(name: String): NominalAttribute = copy(name = Some(name)) + override def withoutName: NominalAttribute = copy(name = None) + + override def withIndex(index: Int): NominalAttribute = copy(index = Some(index)) + override def withoutIndex: NominalAttribute = copy(index = None) + + /** Copy with new values and empty `numValues`. */ + def withValues(values: Array[String]): NominalAttribute = { + copy(numValues = None, values = Some(values)) + } + + /** Copy with new values and empty `numValues`. */ + @varargs + def withValues(first: String, others: String*): NominalAttribute = { + copy(numValues = None, values = Some((first +: others).toArray)) + } + + /** Copy without the values. */ + def withoutValues: NominalAttribute = { + copy(values = None) + } + + /** Copy with a new `numValues` and empty `values`. */ + def withNumValues(numValues: Int): NominalAttribute = { + copy(numValues = Some(numValues), values = None) + } + + /** Copy without the `numValues`. */ + def withoutNumValues: NominalAttribute = copy(numValues = None) + + /** + * Get the number of values, either from `numValues` or from `values`. + * Return None if unknown. + */ + def getNumValues: Option[Int] = { + if (numValues.nonEmpty) { + numValues + } else if (values.nonEmpty) { + Some(values.get.length) + } else { + None + } + } + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + isOrdinal: Option[Boolean] = isOrdinal, + numValues: Option[Int] = numValues, + values: Option[Array[String]] = values): NominalAttribute = { + new NominalAttribute(name, index, isOrdinal, numValues, values) + } + + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder() + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + isOrdinal.foreach(bldr.putBoolean(ORDINAL, _)) + numValues.foreach(bldr.putLong(NUM_VALUES, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: NominalAttribute => + (name == o.name) && + (index == o.index) && + (isOrdinal == o.isOrdinal) && + (numValues == o.numValues) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + isOrdinal.hashCode + sum = 37 * sum + numValues.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods for nominal attributes. */ +object NominalAttribute extends AttributeFactory { + + /** The default nominal attribute. */ + final val defaultAttr: NominalAttribute = new NominalAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): NominalAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val isOrdinal = if (metadata.contains(ORDINAL)) Some(metadata.getBoolean(ORDINAL)) else None + val numValues = + if (metadata.contains(NUM_VALUES)) Some(metadata.getLong(NUM_VALUES).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new NominalAttribute(name, index, isOrdinal, numValues, values) + } +} + +/** + * A binary attribute. + * @param name optional name + * @param index optional index + * @param values optionla values. If set, its size must be 2. + */ +class BinaryAttribute private[ml] ( + override val name: Option[String] = None, + override val index: Option[Int] = None, + val values: Option[Array[String]] = None) + extends Attribute { + + values.foreach { v => + require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.") + } + + override def attrType: AttributeType = AttributeType.Binary + + override def isNumeric: Boolean = true + + override def isNominal: Boolean = true + + override def withName(name: String): BinaryAttribute = copy(name = Some(name)) + override def withoutName: BinaryAttribute = copy(name = None) + + override def withIndex(index: Int): BinaryAttribute = copy(index = Some(index)) + override def withoutIndex: BinaryAttribute = copy(index = None) + + /** + * Copy with new values. + * @param negative name for negative + * @param positive name for positive + */ + def withValues(negative: String, positive: String): BinaryAttribute = + copy(values = Some(Array(negative, positive))) + + /** Copy without the values. */ + def withoutValues: BinaryAttribute = copy(values = None) + + /** Creates a copy of this attribute with optional changes. */ + private def copy( + name: Option[String] = name, + index: Option[Int] = index, + values: Option[Array[String]] = values): BinaryAttribute = { + new BinaryAttribute(name, index, values) + } + + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val bldr = new MetadataBuilder + if (withType) bldr.putString(TYPE, attrType.name) + name.foreach(bldr.putString(NAME, _)) + index.foreach(bldr.putLong(INDEX, _)) + values.foreach(v => bldr.putStringArray(VALUES, v)) + bldr.build() + } + + override def equals(other: Any): Boolean = { + other match { + case o: BinaryAttribute => + (name == o.name) && + (index == o.index) && + (values.map(_.toSeq) == o.values.map(_.toSeq)) + case _ => + false + } + } + + override def hashCode: Int = { + var sum = 17 + sum = 37 * sum + name.hashCode + sum = 37 * sum + index.hashCode + sum = 37 * sum + values.map(_.toSeq).hashCode + sum + } +} + +/** Factory methods for binary attributes. */ +object BinaryAttribute extends AttributeFactory { + + /** The default binary attribute. */ + final val defaultAttr: BinaryAttribute = new BinaryAttribute + + private[attribute] override def fromMetadata(metadata: Metadata): BinaryAttribute = { + import org.apache.spark.ml.attribute.AttributeKeys._ + val name = if (metadata.contains(NAME)) Some(metadata.getString(NAME)) else None + val index = if (metadata.contains(INDEX)) Some(metadata.getLong(INDEX).toInt) else None + val values = + if (metadata.contains(VALUES)) Some(metadata.getStringArray(VALUES)) else None + new BinaryAttribute(name, index, values) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java new file mode 100644 index 000000000000..e3474f3c1d3f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The content here should be in sync with `package.scala`. + +/** + *

    ML attributes

    + * + * The ML pipeline API uses {@link org.apache.spark.sql.DataFrame}s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + *

    ML columns

    + * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use {@link org.apache.spark.ml.attribute.Attribute} to describe a scalar ML column, and + * {@link org.apache.spark.ml.attribute.AttributeGroup} to describe a vector ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package org.apache.spark.ml.attribute; diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala new file mode 100644 index 000000000000..7ac21d7d563f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} + +/** + * ==ML attributes== + * + * The ML pipeline API uses [[DataFrame]]s as ML datasets. + * Each dataset consists of typed columns, e.g., string, double, vector, etc. + * However, knowing only the column type may not be sufficient to handle the data properly. + * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, + * which cannot be treated as numeric values in ML algorithms, and, for another instance, we may + * want to know the names and types of features stored in a vector column. + * ML attributes are used to provide additional information to describe columns in a dataset. + * + * ===ML columns=== + * + * A column with ML attributes attached is called an ML column. + * The data in ML columns are stored as double values, i.e., an ML column is either a scalar column + * of double values or a vector column. + * Columns of other types must be encoded into ML columns using transformers. + * We use [[Attribute]] to describe a scalar ML column, and [[AttributeGroup]] to describe a vector + * ML column. + * ML attributes are stored in the metadata field of the column schema. + */ +package object attribute diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 000000000000..29339c98f51c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * Params for classification. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait ClassifierParams extends PredictorParams + with HasRawPredictionCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Single-label binary or multiclass classification. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] + with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * Model produced by a [[Classifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with ClassifierParams { + + /** @group setParam */ + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + def numClasses: Int + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + if (numColsOutput == 0) { + logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + * + * This default implementation for classification predicts the index of the maximum value + * from [[predictRaw()]]. + */ + @DeveloperApi + override protected def predict(features: FeaturesType): Double = { + predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + } + + /** + * :: DeveloperApi :: + * + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + @DeveloperApi + protected def predictRaw(features: FeaturesType): Vector + +} + +private[ml] object ClassificationModel { + + /** + * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] + * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. + * @param dataset Input dataset + * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge + * should already be done. + * @return (number of columns added, transformed dataset) + */ + def transformColumnsImpl[FeaturesType]( + dataset: DataFrame, + model: ClassificationModel[FeaturesType, _], + map: ParamMap): (Int, DataFrame) = { + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var tmpData = dataset + var numColsOutput = 0 + if (map(model.rawPredictionCol) != "") { + // output raw prediction + val features2raw: FeaturesType => Vector = model.predictRaw + tmpData = tmpData.withColumn(map(model.rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(model.featuresCol)))) + numColsOutput += 1 + if (map(model.predictionCol) != "") { + val raw2pred: Vector => Double = (rawPred) => { + rawPred.toArray.zipWithIndex.maxBy(_._1)._2 + } + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol)))) + numColsOutput += 1 + } + } else if (map(model.predictionCol) != "") { + // output prediction + val features2pred: FeaturesType => Double = model.predict + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(features2pred, DoubleType, col(map(model.featuresCol)))) + numColsOutput += 1 + } + (numColsOutput, tmpData) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala new file mode 100644 index 000000000000..3855e396b553 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassifier + extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + with DecisionTreeParams + with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = + super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + + s" with invalid label column, without the number of classes specified.") + // TODO: Automatically index labels. + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + override private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses) + strategy.algo = OldAlgo.Classification + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeClassifier { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassificationModel private[ml] ( + override val parent: DecisionTreeClassifier, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeClassificationModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeClassificationModel = { + val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } +} + +private[ml] object DecisionTreeClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f831..cc8b0721cf2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -18,132 +18,184 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel + /** - * :: AlphaComponent :: * Params for logistic regression. */ -@AlphaComponent -private[classification] trait LogisticRegressionParams extends Params - with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol - with HasScoreCol with HasPredictionCol { +private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold { - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param paramMap additional parameters - * @param fitting whether this is in fitting - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - paramMap: ParamMap, - fitting: Boolean): StructType = { - val map = this.paramMap ++ paramMap - val featuresType = schema(map(featuresCol)).dataType - // TODO: Support casting Array[Double] and Array[Float] to Vector. - require(featuresType.isInstanceOf[VectorUDT], - s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") - if (fitting) { - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") - } - val fieldNames = schema.fieldNames - require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") - require(!fieldNames.contains(map(predictionCol)), - s"Prediction column ${map(predictionCol)} already exists.") - val outputFields = schema.fields ++ Seq( - StructField(map(scoreCol), DoubleType, false), - StructField(map(predictionCol), DoubleType, false)) - StructType(outputFields) - } + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) } /** + * :: AlphaComponent :: + * * Logistic regression. + * Currently, this class only supports binary classification. */ -class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { - - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) +@AlphaComponent +class LogisticRegression + extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] + with LogisticRegressionParams { + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { - transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - }.persist(StorageLevel.MEMORY_AND_DISK) - val lr = new LogisticRegressionWithLBFGS + override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) - val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) - instances.unpersist() - // copy model params - Params.inheritValues(map, this, lrm) - lrm - } + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val oldModel = lr.run(oldDataset) + val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true) + if (handlePersistence) { + oldDataset.unpersist() + } + lrm } } + /** * :: AlphaComponent :: + * * Model produced by [[LogisticRegression]]. */ @AlphaComponent class LogisticRegressionModel private[ml] ( override val parent: LogisticRegression, override val fittingParamMap: ParamMap, - weights: Vector) - extends Model[LogisticRegressionModel] with LogisticRegressionParams { + val weights: Vector, + val intercept: Double) + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] + with LogisticRegressionParams { + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false) + private val margin: Vector => Double = (features) => { + BLAS.dot(features, weights) + intercept } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + private val score: Vector => Double = (features) => { + val m = margin(features) + 1.0 / (1.0 + math.exp(-m)) + } + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This is overridden (a) to be more efficient (avoiding re-computing values when creating + // multiple output columns) and (b) to handle threshold, which the abstractions do not use. + // TODO: We should abstract away the steps defined by UDFs below so that the abstractions + // can call whichever UDFs are needed to create the output columns. + + // Check schema transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val score: Vector => Double = (v) => { - val margin = BLAS.dot(v, weights) - 1.0 / (1.0 + math.exp(-margin)) + + val map = extractParamMap(paramMap) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + // rawPrediction (-margin, margin) + // probability (1.0-score, score) + // prediction (max margin) + var tmpData = dataset + var numColsOutput = 0 + if (map(rawPredictionCol) != "") { + val features2raw: Vector => Vector = (features) => predictRaw(features) + tmpData = tmpData.withColumn(map(rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(featuresCol)))) + numColsOutput += 1 + } + if (map(probabilityCol) != "") { + if (map(rawPredictionCol) != "") { + val raw2prob = udf { (rawPreds: Vector) => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + Vectors.dense(1.0 - prob1, prob1): Vector + } + tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol)))) + } else { + val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } + tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol)))) + } + numColsOutput += 1 } - val t = map(threshold) - val predict: Double => Double = (score) => { - if (score > t) 1.0 else 0.0 + if (map(predictionCol) != "") { + val t = map(threshold) + if (map(probabilityCol) != "") { + val predict = udf { probs: Vector => + if (probs(1) > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol)))) + } else if (map(rawPredictionCol) != "") { + val predict = udf { rawPreds: Vector => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + if (prob1 > t) 1.0 else 0.0 + } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol)))) + } else { + val predict = udf { features: Vector => this.predict(features) } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol)))) + } + numColsOutput += 1 } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + if (numColsOutput == 0) { + this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + + " since no output columns were set.") + } + tmpData + } + + override val numClasses: Int = 2 + + /** + * Predict label for the given feature vector. + * The behavior of this can be adjusted using [[threshold]]. + */ + override protected def predict(features: Vector): Double = { + if (score(features) > getThreshold) 1 else 0 + } + + override protected def predictProbabilities(features: Vector): Vector = { + val s = score(features) + Vectors.dense(1.0 - s, s) + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(0.0, m) + } + + override protected def copy(): LogisticRegressionModel = { + val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.extractParamMap(), this, m) + m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 000000000000..10404548ccfd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Params for probabilistic classification. + */ +private[classification] trait ProbabilisticClassifierParams + extends ClassifierParams with HasProbabilityCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) + } +} + + +/** + * :: AlphaComponent :: + * + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] +} + + +/** + * :: AlphaComponent :: + * + * Model produced by a [[ProbabilisticClassifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + + /** @group setParam */ + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] + * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + + // Output selected columns only. + if (map(probabilityCol) != "") { + // output probabilities + val features2probs: FeaturesType => Vector = (features) => { + tmpModel.predictProbabilities(features) + } + outputData.withColumn(map(probabilityCol), + callUDF(features2probs, new VectorUDT, col(map(featuresCol)))) + } else { + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + } + + /** + * :: DeveloperApi :: + * + * Predict the probability of each class given the features. + * These predictions are also called class conditional probabilities. + * + * WARNING: Not all models output well-calibrated probability estimates! These probabilities + * should be treated as confidences, not precise probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + */ + @DeveloperApi + protected def predictProbabilities(features: FeaturesType): Vector +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b571..c865eb9fe092 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,44 +18,56 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: + * * Evaluator for binary classification, which expects two input columns: score and label. */ @AlphaComponent class BinaryClassificationEvaluator extends Evaluator with Params - with HasScoreCol with HasLabelCol { + with HasRawPredictionCol with HasLabelCol { - /** param for metric name in evaluation */ + /** + * param for metric name in evaluation + * @group param + */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) - def getMetricName: String = get(metricName) + "metric name in evaluation (areaUnderROC|areaUnderPR)") + + /** @group getParam */ + def getMetricName: String = getOrDefault(metricName) + + /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) + /** @group setParam */ + def setScoreCol(value: String): this.type = set(rawPredictionCol, value) + + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + setDefault(metricName -> "areaUnderROC") + + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + val map = extractParamMap(paramMap) val schema = dataset.schema - val scoreType = schema(map(scoreCol)).dataType - require(scoreType == DoubleType, - s"Score column ${map(scoreCol)} must be double type but found $scoreType") - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Label column ${map(labelCol)} must be double type but found $labelType") + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) - .map { case Row(score: Double, label: Double) => - (score, label) + // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. + val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) + .map { case Row(rawPrediction: Vector, label: Double) => + (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = map(metricName) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0956062643f2..b20f2fc49a8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,10 +31,19 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { - /** number of features */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) - def setNumFeatures(value: Int) = set(numFeatures, value) - def getNumFeatures: Int = get(numFeatures) + /** + * number of features + * @group param + */ + val numFeatures = new IntParam(this, "numFeatures", "number of features") + + /** @group getParam */ + def getNumFeatures: Int = getOrDefault(numFeatures) + + /** @group setParam */ + def setNumFeatures(value: Int): this.type = set(numFeatures, value) + + setDefault(numFeatures -> (1 << 18)) override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala new file mode 100644 index 000000000000..e6a62d998bb9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * Params for [[IDF]] and [[IDFModel]]. + */ +private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol { + + /** + * The minimum of documents in which a term should appear. + * @group param + */ + final val minDocFreq = new IntParam( + this, "minDocFreq", "minimum of documents in which a term should appear for filtering") + + setDefault(minDocFreq -> 0) + + /** @group getParam */ + def getMinDocFreq: Int = getOrDefault(minDocFreq) + + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Compute the Inverse Document Frequency (IDF) given a collection of documents. + */ +@AlphaComponent +final class IDF extends Estimator[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val idf = new feature.IDF(map(minDocFreq)).fit(input) + val model = new IDFModel(this, map, idf) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[IDF]]. + */ +@AlphaComponent +class IDFModel private[ml] ( + override val parent: IDF, + override val fittingParamMap: ParamMap, + idfModel: feature.IDFModel) + extends Model[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val idf = udf { vec: Vector => idfModel.transform(vec) } + dataset.withColumn(map(outputCol), idf(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala new file mode 100644 index 000000000000..decaeb0da624 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Normalize a vector to have unit norm using the given p-norm. + */ +@AlphaComponent +class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { + + /** + * Normalization in L^p^ space, p = 2 by default. + * @group param + */ + val p = new DoubleParam(this, "p", "the p norm value") + + /** @group getParam */ + def getP: Double = getOrDefault(p) + + /** @group setParam */ + def setP(value: Double): this.type = set(p, value) + + setDefault(p -> 2.0) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { + val normalizer = new feature.Normalizer(paramMap(p)) + normalizer.transform + } + + override protected def outputDataType: DataType = new VectorUDT() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala new file mode 100644 index 000000000000..d855f04799ae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, + * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an + * expansion of a product of sums expresses it as a sum of products by using the fact that + * multiplication distributes over addition". Take a 2-variable feature vector as an example: + * `(x, y)`, if we want to expand it with degree 2, then we get `(x, y, x * x, x * y, y * y)`. + */ +@AlphaComponent +class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + /** + * The polynomial degree to expand, which should be larger than 1. + * @group param + */ + val degree = new IntParam(this, "degree", "the polynomial degree to expand") + setDefault(degree -> 2) + + /** @group getParam */ + def getDegree: Int = getOrDefault(degree) + + /** @group setParam */ + def setDegree(value: Int): this.type = set(degree, value) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v => + val d = paramMap(degree) + PolynomialExpansion.expand(v, d) + } + + override protected def outputDataType: DataType = new VectorUDT() +} + +/** + * The expansion is done via recursion. Given n features and degree d, the size after expansion is + * (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the + * function that expands [a, b, c] to their monomials of degree 3. We have the following recursion: + * + * {{{ + * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3] + * }}} + * + * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the + * current index and increment it properly for sparse input. + */ +object PolynomialExpansion { + + private def choose(n: Int, k: Int): Int = { + Range(n, n - k, -1).product / Range(k, 1, -1).product + } + + private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) + + private def expandDense( + values: Array[Double], + lastIdx: Int, + degree: Int, + multiplier: Double, + polyValues: Array[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + if (curPolyIdx >= 0) { // skip the very first 1 + polyValues(curPolyIdx) = multiplier + } + } else { + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + var alpha = multiplier + var i = 0 + var curStart = curPolyIdx + while (i <= degree && alpha != 0.0) { + curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastIdx + 1, degree) + } + + private def expandSparse( + indices: Array[Int], + values: Array[Double], + lastIdx: Int, + lastFeatureIdx: Int, + degree: Int, + multiplier: Double, + polyIndices: mutable.ArrayBuilder[Int], + polyValues: mutable.ArrayBuilder[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + if (curPolyIdx >= 0) { // skip the very first 1 + polyIndices += curPolyIdx + polyValues += multiplier + } + } else { + // Skip all zeros at the tail. + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + val lastFeatureIdx1 = indices(lastIdx) - 1 + var alpha = multiplier + var curStart = curPolyIdx + var i = 0 + while (i <= degree && alpha != 0.0) { + curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha, + polyIndices, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastFeatureIdx + 1, degree) + } + + private def expand(dv: DenseVector, degree: Int): DenseVector = { + val n = dv.size + val polySize = getPolySize(n, degree) + val polyValues = new Array[Double](polySize - 1) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, -1) + new DenseVector(polyValues) + } + + private def expand(sv: SparseVector, degree: Int): SparseVector = { + val polySize = getPolySize(sv.size, degree) + val nnz = sv.values.length + val nnzPolySize = getPolySize(nnz, degree) + val polyIndices = mutable.ArrayBuilder.make[Int] + polyIndices.sizeHint(nnzPolySize - 1) + val polyValues = mutable.ArrayBuilder.make[Double] + polyValues.sizeHint(nnzPolySize - 1) + expandSparse( + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, -1) + new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) + } + + def expand(v: Vector, degree: Int): Vector = { + v match { + case dv: DenseVector => expand(dv, degree) + case sv: SparseVector => expand(sv, degree) + case _ => throw new IllegalArgumentException + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e0218..447851ec034d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,17 +20,32 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ -private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * False by default. Centers the data with mean before scaling. + * It will build a dense output, so this does not work on sparse input + * and will raise an exception. + * @group param + */ + val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + + /** + * True by default. Scales the data to unit standard deviation. + * @group param + */ + val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") +} /** * :: AlphaComponent :: @@ -40,25 +55,33 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + setDefault(withMean -> false, withStd -> true) + + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) - def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setWithMean(value: Boolean): this.type = set(withMean, value) + + /** @group setParam */ + def setWithStd(value: Boolean): this.type = set(withStd, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } - val scaler = new feature.StandardScaler().fit(input) - val model = new StandardScalerModel(this, map, scaler) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd)) + val scalerModel = scaler.fit(input) + val model = new StandardScalerModel(this, map, scalerModel) Params.inheritValues(map, this, model) model } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -80,21 +103,21 @@ class StandardScalerModel private[ml] ( scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val scale: (Vector) => Vector = (v) => { - scaler.transform(v) - } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + val map = extractParamMap(paramMap) + val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) + dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 000000000000..23956c512c8a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputFields = schema.fields + val outputColName = map(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + */ +@AlphaComponent +class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // TODO: handle unseen labels + + override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { + val map = extractParamMap(paramMap) + val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val model = new StringIndexerModel(this, map, labels) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StringIndexer]]. + */ +@AlphaComponent +class StringIndexerModel private[ml] ( + override val parent: StringIndexer, + override val fittingParamMap: ParamMap, + labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + // TODO: handle unseen labels + throw new SparkException(s"Unseen label: $label.") + } + } + val outputColName = map(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toMetadata() + dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index e622a5cf9e6f..376a004858b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param} import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** @@ -29,11 +29,75 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} @AlphaComponent class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { - protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { _.toLowerCase.split("\\s") } - protected override def validateInputType(inputType: DataType): Unit = { + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} + +/** + * :: AlphaComponent :: + * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) + * or using it to split the text (set matching to false). Optional parameters also allow to fold + * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length. + * It returns an array of strings that can be empty. + * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true, + * lowercase = false, minTokenLength = 1 + */ +@AlphaComponent +class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + /** + * param for minimum token length, default is one to avoid returning empty strings + * @group param + */ + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") + + /** @group setParam */ + def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) + + /** @group getParam */ + def getMinTokenLength: Int = getOrDefault(minTokenLength) + + /** + * param sets regex as splitting on gaps (true) or matching tokens (false) + * @group param + */ + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") + + /** @group setParam */ + def setGaps(value: Boolean): this.type = set(gaps, value) + + /** @group getParam */ + def getGaps: Boolean = getOrDefault(gaps) + + /** + * param sets regex pattern used by tokenizer + * @group param + */ + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") + + /** @group setParam */ + def setPattern(value: String): this.type = set(pattern, value) + + /** @group getParam */ + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") + + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => + val re = paramMap(pattern).r + val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = paramMap(minTokenLength) + tokens.filter(_.length >= minLength) + } + + override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala new file mode 100644 index 000000000000..7b2a451ca5ee --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: AlphaComponent :: + * A feature transformer than merge multiple columns into a vector column. + */ +@AlphaComponent +class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = extractParamMap(paramMap) + val assembleFunc = udf { r: Row => + VectorAssembler.assemble(r.toSeq: _*) + } + val schema = dataset.schema + val inputColNames = map(inputCols) + val args = inputColNames.map { c => + schema(c).dataType match { + case DoubleType => UnresolvedAttribute(c) + case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) + case _: NumericType | BooleanType => + Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + } + } + dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val inputColNames = map(inputCols) + val outputColName = map(outputCol) + val inputDataTypes = inputColNames.map(name => schema(name).dataType) + inputDataTypes.foreach { + case _: NumericType | BooleanType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + } +} + +@AlphaComponent +object VectorAssembler { + + private[feature] def assemble(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 000000000000..452faa06e202 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, + Attribute, AttributeGroup} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take." + + " If a feature is found to have > maxCategories values, then it is declared continuous.") + + /** @group getParam */ + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) +} + +/** + * :: AlphaComponent :: + * + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categorical any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. + */ +@AlphaComponent +class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = { + require(value > 1, + s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") + set(maxCategories, value) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val firstRow = dataset.select(map(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val maxCats = map(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) + } +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. + if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + while (i < dv.size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + while (vecIndex < sv.size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: AlphaComponent :: + * + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@AlphaComponent +class VectorIndexerModel private[ml] ( + override val parent: VectorIndexer, + override val fittingParamMap: ParamMap, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. + */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + private val transformFunc: Vector => Vector = { + val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val f: Vector => Vector = { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val newField = prepOutputField(dataset.schema, map) + val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) + // For now, just check the first row of inputCol for vector length. + val firstRow = dataset.select(map(inputCol)).take(1) + if (firstRow.length != 0) { + val actualNumFeatures = firstRow(0).getAs[Vector](0).size + require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length $actualNumFeatures") + } + dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + val dataType = new VectorUDT + require(map.contains(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(map.contains(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema, map) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @param map Parameter map (with this class' embedded parameter map folded in) + * @return Output column field + */ + private def prepOutputField(schema: StructType, map: ParamMap): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala new file mode 100644 index 000000000000..195333a5cc47 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.estimator + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * + * Trait for parameters for prediction (regression and classification). + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val map = extractParamMap(paramMap) + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) + if (fitting) { + // TODO: Allow other numeric types + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + /** @group setParam */ + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame, paramMap: ParamMap): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val model = train(dataset, map) + Params.inheritValues(map, this, model) // copy params to model + model + } + + /** + * :: DeveloperApi :: + * + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already + * been combined with the embedded ParamMap. + * @return Fitted model + */ + @DeveloperApi + protected def train(dataset: DataFrame, paramMap: ParamMap): M + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { + val map = extractParamMap(paramMap) + dataset.select(map(labelCol), map(featuresCol)) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + if (map(predictionCol) != "") { + val pred: FeaturesType => Double = (features) => { + tmpModel.predict(features) + } + dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol)))) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. + */ + protected def copy(): M +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala new file mode 100644 index 000000000000..eb2609faef05 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.impl.estimator.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, + Impurity => OldImpurity, Variance => OldVariance} + + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree." + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. + * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.") + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.") + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.") + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.") + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = { + require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") + set(maxDepth, value) + this + } + + /** @group getParam */ + def getMaxDepth: Int = getOrDefault(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = { + require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") + set(maxBins, value) + this + } + + /** @group getParam */ + def getMaxBins: Int = getOrDefault(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = { + require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") + set(minInstancesPerNode, value) + this + } + + /** @group getParam */ + def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = { + set(minInfoGain, value) + this + } + + /** @group getParam */ + def getMinInfoGain: Double = getOrDefault(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = { + require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") + set(maxMemoryInMB, value) + this + } + + /** @group expertGetParam */ + def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = { + set(cacheNodeIds, value) + this + } + + /** @group expertGetParam */ + def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = { + require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") + set(checkpointInterval, value) + this + } + + /** @group expertGetParam */ + def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, + * the default for single trees). + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = 1.0 // default for individual trees + strategy + } +} + +/** + * (private trait) Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeClassifierParams.supportedImpurities.contains(impurityStr), + s"Tree-based classifier was given unrecognized impurity: $value." + + s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * (private trait) Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeRegressorParams.supportedImpurities.contains(impurityStr), + s"Tree-based regressor was given unrecognized impurity: $value." + + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index 51cd48c90432..ac75e9de1a8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -20,5 +20,31 @@ package org.apache.spark /** * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. + * + * @groupname param Parameters + * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get + * the parameter values through setters and getters, respectively. + * @groupprio param -5 + * + * @groupname setParam Parameter setters + * @groupprio setParam 5 + * + * @groupname getParam Parameter getters + * @groupprio getParam 6 + * + * @groupname expertParam (expert-only) Parameters + * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can + * take. Users can set and get the parameter values through setters and getters, + * respectively. + * @groupprio expertParam 7 + * + * @groupname expertSetParam (expert-only) Parameter setters + * @groupprio expertSetParam 8 + * + * @groupname expertGetParam (expert-only) Parameter getters + * @groupprio expertGetParam 9 + * + * @groupname Ungrouped Members + * @groupprio Ungrouped 0 */ package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 04f9cfb1bfc2..ddc5907e7fac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable /** @@ -36,12 +37,7 @@ import org.apache.spark.ml.Identifiable * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -53,48 +49,55 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. + */ override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { - s"$name: $doc" + "(undefined)" } + s"$name: $doc $valueStr" } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) - extends Param[Double](parent, name, doc, defaultValue) { +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) - extends Param[Int](parent, name, doc, defaultValue) { +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) - extends Param[Float](parent, name, doc, defaultValue) { +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) - extends Param[Long](parent, name, doc, defaultValue) { +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) - extends Param[Boolean](parent, name, doc, defaultValue) { +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -112,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -141,44 +147,144 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - private[ml] def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) + protected final def set[T](param: Param[T], value: T): this.type = { + shouldOwn(param) paramMap.put(param.asInstanceOf[Param[Any]], value) this } /** - * Gets the value of a parameter in the embedded param map. + * Sets a parameter (by name) in the embedded param map. + */ + protected final def set(param: String, value: Any): this.type = { + set(getParam(param), value) + } + + /** + * Optionally returns the user-supplied value of a param. + */ + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) + } + + /** + * Clears the user-supplied value for the input param. + */ + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } + + /** + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. + */ + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get + } + + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value + */ + protected final def setDefault[T](param: Param[T], value: T): this.type = { + shouldOwn(param) + defaultParamMap.put(param, value) + this + } + + /** + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. */ - private[ml] def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this } /** - * Internal param map. + * Gets the default value of a parameter. */ - protected val paramMap: ParamMap = ParamMap.empty + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) + } + + /** + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + */ + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap + } + + /** + * [[extractParamMap]] with no extra values. + */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") + } } -private[ml] object Params { +/** + * :: DeveloperApi :: + * + * Helper functionality for developers. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] object Params { /** * Copies parameter values from the parent estimator to the child model it produced. @@ -190,8 +296,9 @@ private[ml] object Params { paramMap: ParamMap, parent: E, child: M): Unit = { + val childParams = child.params.map(_.name).toSet parent.params.foreach { param => - if (paramMap.contains(param)) { + if (paramMap.contains(param) && childParams.contains(param.name)) { child.set(child.getParam(param.name), paramMap(param)) } } @@ -203,12 +310,13 @@ private[ml] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -230,12 +338,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -243,10 +356,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -258,6 +368,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -267,26 +384,25 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.map { case (param, value) => + map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent.uid}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. new ParamMap(this.map ++ other.map) } - /** * Adds all parameters from the input param map into this param map. */ @@ -304,6 +420,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten ParamPair(param, value) } } + + /** + * Number of param pairs in this map. + */ + def size: Int = map.size } object ParamMap { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 000000000000..95d7e64790c7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name$defaultValueDoc. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + |$setDefault + | /** @group getParam */ + | final def get$Name: $T = getOrDefault($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 000000000000..72b08bf27648 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = getOrDefault(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol (default: "features"). + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = getOrDefault(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol (default: "label"). + */ +@DeveloperApi +trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = getOrDefault(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol (default: "prediction"). + */ +@DeveloperApi +trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = getOrDefault(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol (default: "probability"). + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = getOrDefault(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + + /** @group getParam */ + final def getThreshold: Double = getOrDefault(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = getOrDefault(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCols. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = getOrDefault(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) +} + +/** + * :: DeveloperApi :: + * Trait for shared param fitIntercept (default: true). + */ +@DeveloperApi +trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = getOrDefault(fitIntercept) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index ef141d3eb2b0..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -private[ml] trait HasRegParam extends Params { - /** param for regularization parameter */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** param for max number of iterations */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** param for features column name */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** param for label column name */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasScoreCol extends Params { - /** param for score column name */ - val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) - def getScoreCol: String = get(scoreCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** param for prediction column name */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasThreshold extends Params { - /** param for threshold in (binary) prediction */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** param for input column name */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasOutputCol extends Params { - /** param for output column name */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - def getOutputCol: String = get(outputCol) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2d89e76a4c8b..bd793beba35b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -18,22 +18,29 @@ package org.apache.spark.ml.recommendation import java.{util => ju} +import java.io.IOException import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.Sorting +import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW -import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.{Logging, Partitioner} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom @@ -42,42 +49,94 @@ import org.apache.spark.util.random.XORShiftRandom * Common params for ALS. */ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam - with HasPredictionCol { + with HasPredictionCol with HasCheckpointInterval { - /** Param for rank of the matrix factorization. */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) - def getRank: Int = get(rank) + /** + * Param for rank of the matrix factorization. + * @group param + */ + val rank = new IntParam(this, "rank", "rank of the factorization") - /** Param for number of user blocks. */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) - def getNumUserBlocks: Int = get(numUserBlocks) + /** @group getParam */ + def getRank: Int = getOrDefault(rank) - /** Param for number of item blocks. */ + /** + * Param for number of user blocks. + * @group param + */ + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") + + /** @group getParam */ + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) + + /** + * Param for number of item blocks. + * @group param + */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) - def getNumItemBlocks: Int = get(numItemBlocks) + new IntParam(this, "numItemBlocks", "number of item blocks") - /** Param to decide whether to use implicit preference. */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) - def getImplicitPrefs: Boolean = get(implicitPrefs) + /** @group getParam */ + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) - /** Param for the alpha parameter in the implicit preference formulation. */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) - def getAlpha: Double = get(alpha) + /** + * Param to decide whether to use implicit preference. + * @group param + */ + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") - /** Param for the column name for user ids. */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) - def getUserCol: String = get(userCol) + /** @group getParam */ + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) - /** Param for the column name for item ids. */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) - def getItemCol: String = get(itemCol) + /** + * Param for the alpha parameter in the implicit preference formulation. + * @group param + */ + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") - /** Param for the column name for ratings. */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) - def getRatingCol: String = get(ratingCol) + /** @group getParam */ + def getAlpha: Double = getOrDefault(alpha) + + /** + * Param for the column name for user ids. + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = getOrDefault(userCol) + + /** + * Param for the column name for item ids. + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = getOrDefault(itemCol) + + /** + * Param for the column name for ratings. + * @group param + */ + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") + + /** @group getParam */ + def getRatingCol: String = getOrDefault(ratingCol) + + /** + * Param for whether to apply nonnegativity constraints. + * @group param + */ + val nonnegative = new BooleanParam( + this, "nonnegative", "whether to use nonnegative constraint for least squares") + + /** @group getParam */ + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. @@ -86,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -110,49 +169,35 @@ class ALSModel private[ml] ( itemFactors: RDD[(Int, Array[Float])]) extends Model[ALSModel] with ALSParams { + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { - import dataset.sqlContext._ - import org.apache.spark.ml.recommendation.ALSModel.Factor - val map = this.paramMap ++ paramMap - // TODO: Add DSL to simplify the code here. - val instanceTable = s"instance_$uid" - val userTable = s"user_$uid" - val itemTable = s"item_$uid" - val instances = dataset.as(Symbol(instanceTable)) - val users = userFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(userTable)) - val items = itemFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(itemTable)) - val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + import dataset.sqlContext.implicits._ + val map = extractParamMap(paramMap) + val users = userFactors.toDF("id", "features") + val items = itemFactors.toDF("id", "features") + + // Register a UDF for DataFrame, and then + // create a new column named map(predictionCol) by running the predict UDF. + val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } - val inputColumns = dataset.schema.fieldNames - val prediction = - predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) - val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction - instances - .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) - .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) - .select(outputColumns: _*) + dataset + .join(users, dataset(map(userCol)) === users("id"), "left") + .join(items, dataset(map(itemCol)) === items("id"), "left") + .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol))) } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap) } } -private object ALSModel { - /** Case class to convert factors to SchemaRDDs */ - private case class Factor(id: Int, features: Seq[Float]) -} /** * Alternating Least Squares (ALS) matrix factorization. @@ -187,19 +232,49 @@ class ALS extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) + + /** @group setParam */ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + + /** @group setParam */ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + + /** @group setParam */ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + + /** @group setParam */ def setAlpha(value: Double): this.type = set(alpha, value) + + /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ def setItemCol(value: String): this.type = set(itemCol, value) + + /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) + + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) - /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + /** @group setParam */ + def setNonnegative(value: Boolean): this.type = set(nonnegative, value) + + /** @group setParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** + * Sets both numUserBlocks and numItemBlocks to the specific value. + * @group setParam + */ def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) @@ -208,60 +283,75 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) - - override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { - import dataset.sqlContext._ - val map = this.paramMap ++ paramMap - val ratings = - dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) - .map { row => - new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) - } + setCheckpointInterval(10) + + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { + val map = extractParamMap(paramMap) + val ratings = dataset + .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .map { row => + Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), - alpha = map(alpha)) + alpha = map(alpha), nonnegative = map(nonnegative), + checkpointInterval = map(checkpointInterval)) val model = new ALSModel(this, map, map(rank), userFactors, itemFactors) Params.inheritValues(map, this, model) model } - override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { validateAndTransformSchema(schema, paramMap) } } -private[recommendation] object ALS extends Logging { +/** + * :: DeveloperApi :: + * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is + * exposed as a developer API for users who do need other ID types. But it is not recommended + * because it increases the shuffle size and memory requirement during training. For simplicity, + * users and items must have the same type. The number of distinct users/items should be smaller + * than 2 billion. + */ +@DeveloperApi +object ALS extends Logging { /** Rating class for better code readability. */ - private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + + /** Trait for least squares solvers applied to the normal equation. */ + private[recommendation] trait LeastSquaresNESolver extends Serializable { + /** Solves a least squares problem with regularization (possibly with other constraints). */ + def solve(ne: NormalEquation, lambda: Double): Array[Float] + } /** Cholesky solver for least square problems. */ - private[recommendation] class CholeskySolver { + private[recommendation] class CholeskySolver extends LeastSquaresNESolver { private val upper = "U" - private val info = new intW(0) /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ - def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } + val info = new intW(0) lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info) val code = info.`val` assert(code == 0, s"lapack.dppsv returned $code.") @@ -276,7 +366,71 @@ private[recommendation] object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** NNLS solver. */ + private[recommendation] class NNLSSolver extends LeastSquaresNESolver { + private var rank: Int = -1 + private var workspace: NNLS.Workspace = _ + private var ata: Array[Double] = _ + private var initialized: Boolean = false + + private def initialize(rank: Int): Unit = { + if (!initialized) { + this.rank = rank + workspace = NNLS.createWorkspace(rank) + ata = new Array[Double](rank * rank) + initialized = true + } else { + require(this.rank == rank) + } + } + + /** + * Solves a nonnegative least squares problem with L2 regularizatin: + * + * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * subject to x >= 0 + */ + override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { + val rank = ne.k + initialize(rank) + fillAtA(ne.ata, lambda) + val x = NNLS.solve(ata, ne.atb, workspace) + ne.reset() + x.map(x => x.toFloat) + } + + /** + * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square + * matrix that it represents, storing it into destMatrix. + */ + private def fillAtA(triAtA: Array[Double], lambda: Double) { + var i = 0 + var pos = 0 + var a = 0.0 + while (i < rank) { + var j = 0 + while (j <= i) { + a = triAtA(pos) + ata(i * rank + j) = a + ata(j * rank + i) = a + pos += 1 + j += 1 + } + ata(i * rank + i) += lambda + i += 1 + } + } + } + + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -285,8 +439,6 @@ private[recommendation] object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -300,28 +452,13 @@ private[recommendation] object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.size == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { - require(a.size == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) + require(a.length == k) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -329,9 +466,8 @@ private[recommendation] object ALS extends Logging { /** Merges another normal equation object. */ def merge(other: NormalEquation): this.type = { require(other.k == k) - blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) - blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) - n += other.n + blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) this } @@ -339,87 +475,132 @@ private[recommendation] object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } /** * Implementation of the ALS algorithm. */ - private def train( - ratings: RDD[Rating], + def train[ID: ClassTag]( // scalastyle:ignore + ratings: RDD[Rating[ID]], rank: Int = 10, numUserBlocks: Int = 10, numItemBlocks: Int = 10, maxIter: Int = 10, regParam: Double = 1.0, implicitPrefs: Boolean = false, - alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { - val userPart = new HashPartitioner(numUserBlocks) - val itemPart = new HashPartitioner(numItemBlocks) + alpha: Double = 1.0, + nonnegative: Boolean = false, + intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + checkpointInterval: Int = 10, + seed: Long = 0L)( + implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(intermediateRDDStorageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext + val userPart = new ALSPartitioner(numUserBlocks) + val itemPart = new ALSPartitioner(numItemBlocks) val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) - val blockRatings = partitionRatings(ratings, userPart, itemPart).cache() - val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart) + val solver = if (nonnegative) new NNLSSolver else new CholeskySolver + val blockRatings = partitionRatings(ratings, userPart, itemPart) + .persist(intermediateRDDStorageLevel) + val (userInBlocks, userOutBlocks) = + makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) // materialize blockRatings and user blocks userOutBlocks.count() val swappedBlockRatings = blockRatings.map { case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) } - val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart) + val (itemInBlocks, itemOutBlocks) = + makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) // materialize item blocks itemOutBlocks.count() - var userFactors = initialize(userInBlocks, rank) - var itemFactors = initialize(itemInBlocks, rank) + val seedGen = new XORShiftRandom(seed) + var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) + var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + var previousCheckpointFile: Option[String] = None + val shouldCheckpoint: Int => Boolean = (iter) => + sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + val deletePreviousCheckpointFile: () => Unit = () => + previousCheckpointFile.foreach { file => + try { + FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true) + } catch { + case e: IOException => + logWarning(s"Cannot delete checkpoint file $file:", e) + } + } if (implicitPrefs) { for (iter <- 1 to maxIter) { - userFactors.setName(s"userFactors-$iter").persist() + userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) val previousItemFactors = itemFactors itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, - userLocalIndexEncoder, implicitPrefs, alpha) + userLocalIndexEncoder, implicitPrefs, alpha, solver) previousItemFactors.unpersist() - itemFactors.setName(s"itemFactors-$iter").persist() + itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) + // TODO: Generalize PeriodicGraphCheckpointer and use it here. + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() // itemFactors gets materialized in computeFactors. + } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, - itemLocalIndexEncoder, implicitPrefs, alpha) + itemLocalIndexEncoder, implicitPrefs, alpha, solver) + if (shouldCheckpoint(iter)) { + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } previousUserFactors.unpersist() } } else { for (iter <- 0 until maxIter) { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, - userLocalIndexEncoder) + userLocalIndexEncoder, solver = solver) + if (shouldCheckpoint(iter)) { + itemFactors.checkpoint() + itemFactors.count() // checkpoint item factors and cut lineage + deletePreviousCheckpointFile() + previousCheckpointFile = itemFactors.getCheckpointFile + } userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, - itemLocalIndexEncoder) + itemLocalIndexEncoder, solver = solver) } } val userIdAndFactors = userInBlocks .mapValues(_.srcIds) .join(userFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks + // and userFactors. + }, preservesPartitioning = true) .setName("userFactors") - .cache() - userIdAndFactors.count() - itemFactors.unpersist() + .persist(finalRDDStorageLevel) val itemIdAndFactors = itemInBlocks .mapValues(_.srcIds) .join(itemFactors) - .values + .mapPartitions({ items => + items.flatMap { case (_, (ids, factors)) => + ids.view.zip(factors) + } + }, preservesPartitioning = true) .setName("itemFactors") - .cache() - itemIdAndFactors.count() - userInBlocks.unpersist() - userOutBlocks.unpersist() - itemInBlocks.unpersist() - itemOutBlocks.unpersist() - blockRatings.unpersist() - val userOutput = userIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) => - ids.view.zip(factors) - } - (userOutput, itemOutput) + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userIdAndFactors.count() + itemFactors.unpersist() + itemIdAndFactors.count() + userInBlocks.unpersist() + userOutBlocks.unpersist() + itemInBlocks.unpersist() + itemOutBlocks.unpersist() + blockRatings.unpersist() + } + (userIdAndFactors, itemIdAndFactors) } /** @@ -457,16 +638,15 @@ private[recommendation] object ALS extends Logging { * * @see [[LocalIndexEncoder]] */ - private[recommendation] case class InBlock( - srcIds: Array[Int], + private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], dstPtrs: Array[Int], dstEncodedIndices: Array[Int], ratings: Array[Float]) { /** Size of the block. */ - val size: Int = ratings.size - - require(dstEncodedIndices.size == size) - require(dstPtrs.size == srcIds.size + 1) + def size: Int = ratings.length + require(dstEncodedIndices.length == size) + require(dstPtrs.length == srcIds.length + 1) } /** @@ -476,15 +656,18 @@ private[recommendation] object ALS extends Logging { * @param rank rank * @return initialized factor blocks */ - private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + private def initialize[ID]( + inBlocks: RDD[(Int, InBlock[ID])], + rank: Int, + seed: Long): RDD[(Int, FactorBlock)] = { // Choose a unit vector uniformly at random from the unit sphere, but from the // "first quadrant" where all elements are nonnegative. This can be done by choosing // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. // This appears to create factorizations that have a slightly better reconstruction // (<1%) compared picking elements uniformly at random in [0,1]. inBlocks.map { case (srcBlockId, inBlock) => - val random = new XORShiftRandom(srcBlockId) - val factors = Array.fill(inBlock.srcIds.size) { + val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId)) + val factors = Array.fill(inBlock.srcIds.length) { val factor = Array.fill(rank)(random.nextGaussian().toFloat) val nrm = blas.snrm2(rank, factor, 1) blas.sscal(rank, 1.0f / nrm, factor, 1) @@ -497,26 +680,29 @@ private[recommendation] object ALS extends Logging { /** * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. */ - private[recommendation] - case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], + dstIds: Array[ID], + ratings: Array[Float]) { /** Size of the block. */ - val size: Int = srcIds.size - require(dstIds.size == size) - require(ratings.size == size) + def size: Int = srcIds.length + require(dstIds.length == srcIds.length) + require(ratings.length == srcIds.length) } /** * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. */ - private[recommendation] class RatingBlockBuilder extends Serializable { + private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag] + extends Serializable { - private val srcIds = mutable.ArrayBuilder.make[Int] - private val dstIds = mutable.ArrayBuilder.make[Int] + private val srcIds = mutable.ArrayBuilder.make[ID] + private val dstIds = mutable.ArrayBuilder.make[ID] private val ratings = mutable.ArrayBuilder.make[Float] var size = 0 /** Adds a rating. */ - def add(r: Rating): this.type = { + def add(r: Rating[ID]): this.type = { size += 1 srcIds += r.user dstIds += r.item @@ -525,8 +711,8 @@ private[recommendation] object ALS extends Logging { } /** Merges another [[RatingBlockBuilder]]. */ - def merge(other: RatingBlock): this.type = { - size += other.srcIds.size + def merge(other: RatingBlock[ID]): this.type = { + size += other.srcIds.length srcIds ++= other.srcIds dstIds ++= other.dstIds ratings ++= other.ratings @@ -534,8 +720,8 @@ private[recommendation] object ALS extends Logging { } /** Builds a [[RatingBlock]]. */ - def build(): RatingBlock = { - RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + def build(): RatingBlock[ID] = { + RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result()) } } @@ -548,10 +734,10 @@ private[recommendation] object ALS extends Logging { * * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) */ - private def partitionRatings( - ratings: RDD[Rating], + private def partitionRatings[ID: ClassTag]( + ratings: RDD[Rating[ID]], srcPart: Partitioner, - dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { /* The implementation produces the same result as the following but generates less objects. @@ -565,7 +751,7 @@ private[recommendation] object ALS extends Logging { val numPartitions = srcPart.numPartitions * dstPart.numPartitions ratings.mapPartitions { iter => - val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) iter.flatMap { r => val srcBlockId = srcPart.getPartition(r.user) val dstBlockId = dstPart.getPartition(r.item) @@ -586,7 +772,7 @@ private[recommendation] object ALS extends Logging { } } }.groupByKey().mapValues { blocks => - val builder = new RatingBlockBuilder + val builder = new RatingBlockBuilder[ID] blocks.foreach(builder.merge) builder.build() }.setName("ratingBlocks") @@ -596,9 +782,11 @@ private[recommendation] object ALS extends Logging { * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. * @param encoder encoder for dst indices */ - private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag]( + encoder: LocalIndexEncoder)( + implicit ord: Ordering[ID]) { - private val srcIds = mutable.ArrayBuilder.make[Int] + private val srcIds = mutable.ArrayBuilder.make[ID] private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] private val ratings = mutable.ArrayBuilder.make[Float] @@ -612,12 +800,12 @@ private[recommendation] object ALS extends Logging { */ def add( dstBlockId: Int, - srcIds: Array[Int], + srcIds: Array[ID], dstLocalIndices: Array[Int], ratings: Array[Float]): this.type = { - val sz = srcIds.size - require(dstLocalIndices.size == sz) - require(ratings.size == sz) + val sz = srcIds.length + require(dstLocalIndices.length == sz) + require(ratings.length == sz) this.srcIds ++= srcIds this.ratings ++= ratings var j = 0 @@ -629,7 +817,7 @@ private[recommendation] object ALS extends Logging { } /** Builds a [[UncompressedInBlock]]. */ - def build(): UncompressedInBlock = { + def build(): UncompressedInBlock[ID] = { new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) } } @@ -637,24 +825,25 @@ private[recommendation] object ALS extends Logging { /** * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. */ - private[recommendation] class UncompressedInBlock( - val srcIds: Array[Int], + private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag]( + val srcIds: Array[ID], val dstEncodedIndices: Array[Int], - val ratings: Array[Float]) { + val ratings: Array[Float])( + implicit ord: Ordering[ID]) { /** Size the of block. */ - def size: Int = srcIds.size + def length: Int = srcIds.length /** * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. */ - def compress(): InBlock = { - val sz = size + def compress(): InBlock[ID] = { + val sz = length assert(sz > 0, "Empty in-link block should not exist.") sort() - val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID] val dstCountsBuilder = mutable.ArrayBuilder.make[Int] var preSrcId = srcIds(0) uniqueSrcIdsBuilder += preSrcId @@ -675,7 +864,7 @@ private[recommendation] object ALS extends Logging { } dstCountsBuilder += curCount val uniqueSrcIds = uniqueSrcIdsBuilder.result() - val numUniqueSrdIds = uniqueSrcIds.size + val numUniqueSrdIds = uniqueSrcIds.length val dstCounts = dstCountsBuilder.result() val dstPtrs = new Array[Int](numUniqueSrdIds + 1) var sum = 0 @@ -689,51 +878,61 @@ private[recommendation] object ALS extends Logging { } private def sort(): Unit = { - val sz = size + val sz = length // Since there might be interleaved log messages, we insert a unique id for easy pairing. val sortId = Utils.random.nextInt() logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") val start = System.nanoTime() - val sorter = new Sorter(new UncompressedInBlockSort) - sorter.sort(this, 0, size, Ordering[IntWrapper]) + val sorter = new Sorter(new UncompressedInBlockSort[ID]) + sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]]) val duration = (System.nanoTime() - start) / 1e9 logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") } } /** - * A wrapper that holds a primitive integer key. + * A wrapper that holds a primitive key. * * @see [[UncompressedInBlockSort]] */ - private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { - override def compare(that: IntWrapper): Int = { - key.compare(that.key) + private class KeyWrapper[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] { + + var key: ID = _ + + override def compare(that: KeyWrapper[ID]): Int = { + ord.compare(key, that.key) + } + + def setKey(key: ID): this.type = { + this.key = key + this } } /** * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. */ - private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) + extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] { - override def newKey(): IntWrapper = new IntWrapper() + override def newKey(): KeyWrapper[ID] = new KeyWrapper() override def getKey( - data: UncompressedInBlock, + data: UncompressedInBlock[ID], pos: Int, - reuse: IntWrapper): IntWrapper = { + reuse: KeyWrapper[ID]): KeyWrapper[ID] = { if (reuse == null) { - new IntWrapper(data.srcIds(pos)) + new KeyWrapper().setKey(data.srcIds(pos)) } else { - reuse.key = data.srcIds(pos) - reuse + reuse.setKey(data.srcIds(pos)) } } override def getKey( - data: UncompressedInBlock, - pos: Int): IntWrapper = { + data: UncompressedInBlock[ID], + pos: Int): KeyWrapper[ID] = { getKey(data, pos, null) } @@ -746,16 +945,16 @@ private[recommendation] object ALS extends Logging { data(pos1) = tmp } - override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = { swapElements(data.srcIds, pos0, pos1) swapElements(data.dstEncodedIndices, pos0, pos1) swapElements(data.ratings, pos0, pos1) } override def copyRange( - src: UncompressedInBlock, + src: UncompressedInBlock[ID], srcPos: Int, - dst: UncompressedInBlock, + dst: UncompressedInBlock[ID], dstPos: Int, length: Int): Unit = { System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) @@ -763,15 +962,15 @@ private[recommendation] object ALS extends Logging { System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) } - override def allocate(length: Int): UncompressedInBlock = { + override def allocate(length: Int): UncompressedInBlock[ID] = { new UncompressedInBlock( - new Array[Int](length), new Array[Int](length), new Array[Float](length)) + new Array[ID](length), new Array[Int](length), new Array[Float](length)) } override def copyElement( - src: UncompressedInBlock, + src: UncompressedInBlock[ID], srcPos: Int, - dst: UncompressedInBlock, + dst: UncompressedInBlock[ID], dstPos: Int): Unit = { dst.srcIds(dstPos) = src.srcIds(srcPos) dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) @@ -787,19 +986,21 @@ private[recommendation] object ALS extends Logging { * @param dstPart partitioner for dst IDs * @return (in-blocks, out-blocks) */ - private def makeBlocks( + private def makeBlocks[ID: ClassTag]( prefix: String, - ratingBlocks: RDD[((Int, Int), RatingBlock)], + ratingBlocks: RDD[((Int, Int), RatingBlock[ID])], srcPart: Partitioner, - dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + dstPart: Partitioner, + storageLevel: StorageLevel)( + implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = { val inBlocks = ratingBlocks.map { case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => // The implementation is a faster version of // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap val start = System.nanoTime() - val dstIdSet = new OpenHashSet[Int](1 << 20) + val dstIdSet = new OpenHashSet[ID](1 << 20) dstIds.foreach(dstIdSet.add) - val sortedDstIds = new Array[Int](dstIdSet.size) + val sortedDstIds = new Array[ID](dstIdSet.size) var i = 0 var pos = dstIdSet.nextPos(0) while (pos != -1) { @@ -808,10 +1009,10 @@ private[recommendation] object ALS extends Logging { i += 1 } assert(i == dstIdSet.size) - ju.Arrays.sort(sortedDstIds) - val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + Sorting.quickSort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length) i = 0 - while (i < sortedDstIds.size) { + while (i < sortedDstIds.length) { dstIdToLocalIndex.update(sortedDstIds(i), i) i += 1 } @@ -819,21 +1020,22 @@ private[recommendation] object ALS extends Logging { "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.") val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply) (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings)) - }.groupByKey(new HashPartitioner(srcPart.numPartitions)) - .mapValues { iter => - val builder = - new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) - iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => - builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) - } - builder.build().compress() - }.setName(prefix + "InBlocks").cache() + }.groupByKey(new ALSPartitioner(srcPart.numPartitions)) + .mapValues { iter => + val builder = + new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) + iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => + builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) + } + builder.build().compress() + }.setName(prefix + "InBlocks") + .persist(storageLevel) val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) => val encoder = new LocalIndexEncoder(dstPart.numPartitions) val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) var i = 0 val seen = new Array[Boolean](dstPart.numPartitions) - while (i < srcIds.size) { + while (i < srcIds.length) { var j = dstPtrs(i) ju.Arrays.fill(seen, false) while (j < dstPtrs(i + 1)) { @@ -849,7 +1051,8 @@ private[recommendation] object ALS extends Logging { activeIds.map { x => x.result() } - }.setName(prefix + "OutBlocks").cache() + }.setName(prefix + "OutBlocks") + .persist(storageLevel) (inBlocks, outBlocks) } @@ -864,19 +1067,21 @@ private[recommendation] object ALS extends Logging { * @param srcEncoder encoder for src local indices * @param implicitPrefs whether to use implicit preference * @param alpha the alpha constant in the implicit preference formulation + * @param solver solver for least squares problems * * @return dst factors */ - private def computeFactors( + private def computeFactors[ID]( srcFactorBlocks: RDD[(Int, FactorBlock)], srcOutBlocks: RDD[(Int, OutBlock)], - dstInBlocks: RDD[(Int, InBlock)], + dstInBlocks: RDD[(Int, InBlock[ID])], rank: Int, regParam: Double, srcEncoder: LocalIndexEncoder, implicitPrefs: Boolean = false, - alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { - val numSrcBlocks = srcFactorBlocks.partitions.size + alpha: Double = 1.0, + solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = { + val numSrcBlocks = srcFactorBlocks.partitions.length val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { case (srcBlockId, (srcOutBlock, srcFactors)) => @@ -884,23 +1089,23 @@ private[recommendation] object ALS extends Logging { (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) } } - val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length)) dstInBlocks.join(merged).mapValues { case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) srcFactors.foreach { case (srcBlockId, factors) => sortedSrcFactors(srcBlockId) = factors } - val dstFactors = new Array[Array[Float]](dstIds.size) + val dstFactors = new Array[Array[Float]](dstIds.length) var j = 0 val ls = new NormalEquation(rank) - val solver = new CholeskySolver // TODO: add NNLS solver - while (j < dstIds.size) { + while (j < dstIds.length) { ls.reset() if (implicitPrefs) { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -908,13 +1113,23 @@ private[recommendation] object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. + dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -928,7 +1143,7 @@ private[recommendation] object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) @@ -970,4 +1185,11 @@ private[recommendation] object ALS extends Logging { encoded & localIndexMask } } + + /** + * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, + * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * satisfies this requirement, we simply use a type alias here. + */ + private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala new file mode 100644 index 000000000000..49a8b77acf96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class DecisionTreeRegressor + extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams + with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) + strategy.algo = OldAlgo.Regression + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeRegressor { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. + * It supports both continuous and categorical features. + * @param rootNode Root of the decision tree + */ +@AlphaComponent +final class DecisionTreeRegressionModel private[ml] ( + override val parent: DecisionTreeRegressor, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeRegressionModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeRegressionModel = { + val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + } + + /** Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) + } +} + +private[ml] object DecisionTreeRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, + s"Cannot convert non-regression DecisionTreeModel (old API) to" + + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala new file mode 100644 index 000000000000..26ca7459c4fd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.sql.DataFrame +import org.apache.spark.storage.StorageLevel + + +/** + * Params for linear regression. + */ +private[regression] trait LinearRegressionParams extends RegressorParams + with HasRegParam with HasMaxIter + + +/** + * :: AlphaComponent :: + * + * Linear regression. + */ +@AlphaComponent +class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] + with LinearRegressionParams { + + setDefault(regParam -> 0.1, maxIter -> 100) + + /** @group setParam */ + def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model + val lr = new LinearRegressionWithSGD() + lr.optimizer + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val model = lr.run(oldDataset) + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + + if (handlePersistence) { + oldDataset.unpersist() + } + lrm + } +} + +/** + * :: AlphaComponent :: + * + * Model produced by [[LinearRegression]]. + */ +@AlphaComponent +class LinearRegressionModel private[ml] ( + override val parent: LinearRegression, + override val fittingParamMap: ParamMap, + val weights: Vector, + val intercept: Double) + extends RegressionModel[Vector, LinearRegressionModel] + with LinearRegressionParams { + + override protected def predict(features: Vector): Double = { + BLAS.dot(features, weights) + intercept + } + + override protected def copy(): LinearRegressionModel = { + val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(extractParamMap(), this, m) + m + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala new file mode 100644 index 000000000000..d679085eeafe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} + +/** + * :: DeveloperApi :: + * Params for regression. + * Currently empty, but may add functionality later. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait RegressorParams extends PredictorParams + +/** + * :: AlphaComponent :: + * + * Single-label regression + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam Learner Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Regressor[ + FeaturesType, + Learner <: Regressor[FeaturesType, Learner, M], + M <: RegressionModel[FeaturesType, M]] + extends Predictor[FeaturesType, Learner, M] + with RegressorParams { + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * + * Model produced by a [[Regressor]]. + * + * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] + * @tparam M Concrete Model type. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with RegressorParams { + + /** + * :: DeveloperApi :: + * + * Predict real-valued label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala new file mode 100644 index 000000000000..d6e2203d9f93 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} + + +/** + * Decision tree node interface. + */ +sealed abstract class Node extends Serializable { + + // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree + // code into the new API and deprecate the old API. + + /** Prediction this node makes (or would make, if it is an internal node) */ + def prediction: Double + + /** Impurity measure at this node (for training data) */ + def impurity: Double + + /** Recursive prediction helper method */ + private[ml] def predict(features: Vector): Double = prediction + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. + */ + private[tree] def subtreeDepth: Int + + /** + * Create a copy of this node in the old Node format, recursively creating child nodes as needed. + * @param id Node ID using old format IDs + */ + private[ml] def toOld(id: Int): OldNode +} + +private[ml] object Node { + + /** + * Create a new Node from the old Node format, recursively creating child nodes as needed. + */ + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + if (oldNode.isLeaf) { + // TODO: Once the implementation has been moved to this API, then include sufficient + // statistics here. + new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + } + } +} + +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +final class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double) extends Node { + + override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predict(features: Vector): Double = prediction + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, + None, None, None, None) + } +} + +/** + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. + * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +final class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split) extends Node { + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predict(features: Vector): Double = { + if (split.shouldGoLeft(features)) { + leftChild.predict(features) + } else { + rightChild.predict(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, + Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } +} + +private object InternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala new file mode 100644 index 000000000000..708c769087dd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} +import org.apache.spark.mllib.tree.model.{Split => OldSplit} + + +/** + * Interface for a "Split," which specifies a test made at a decision tree node + * to choose the left or right path. + */ +sealed trait Split extends Serializable { + + /** Index of feature which this split tests */ + def featureIndex: Int + + /** Return true (split to left) or false (split to right) */ + private[ml] def shouldGoLeft(features: Vector): Boolean + + /** Convert to old Split format */ + private[tree] def toOld: OldSplit +} + +private[tree] object Split { + + def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { + oldSplit.featureType match { + case OldFeatureType.Categorical => + new CategoricalSplit(featureIndex = oldSplit.feature, + leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + case OldFeatureType.Continuous => + new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) + } + } +} + +/** + * Split which tests a categorical feature. + * @param featureIndex Index of the feature to test + * @param leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. + * @param numCategories Number of categories for this feature. + */ +final class CategoricalSplit private[ml] ( + override val featureIndex: Int, + leftCategories: Array[Double], + private val numCategories: Int) + extends Split { + + require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + + /** + * If true, then "categories" is the set of categories for splitting to the left, and vice versa. + */ + private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + + /** Set of categories determining the splitting rule, along with [[isLeft]]. */ + private val categories: Set[Double] = { + if (isLeft) { + leftCategories.toSet + } else { + setComplement(leftCategories.toSet) + } + } + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + if (isLeft) { + categories.contains(features(featureIndex)) + } else { + !categories.contains(features(featureIndex)) + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: CategoricalSplit => featureIndex == other.featureIndex && + isLeft == other.isLeft && categories == other.categories + case _ => false + } + } + + override private[tree] def toOld: OldSplit = { + val oldCats = if (isLeft) { + categories + } else { + setComplement(categories) + } + OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) + } + + /** Get sorted categories which split to the left */ + def getLeftCategories: Array[Double] = { + val cats = if (isLeft) categories else setComplement(categories) + cats.toArray.sorted + } + + /** Get sorted categories which split to the right */ + def getRightCategories: Array[Double] = { + val cats = if (isLeft) setComplement(categories) else categories + cats.toArray.sorted + } + + /** [0, numCategories) \ cats */ + private def setComplement(cats: Set[Double]): Set[Double] = { + Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet + } +} + +/** + * Split which tests a continuous feature. + * @param featureIndex Index of the feature to test + * @param threshold If the feature value is <= this threshold, then the split goes left. + * Otherwise, it goes right. + */ +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + features(featureIndex) <= threshold + } + + override def equals(o: Any): Boolean = { + o match { + case other: ContinuousSplit => + featureIndex == other.featureIndex && threshold == other.threshold + case _ => + false + } + } + + override private[tree] def toOld: OldSplit = { + OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100644 index 000000000000..8e3bc3849dcf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.AlphaComponent + + +/** + * :: AlphaComponent :: + * + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions + */ +@AlphaComponent +trait DecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: Node + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424..4bb4ed813c00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,30 +24,52 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { - /** param for the estimator to be cross-validated */ + + /** + * param for the estimator to be cross-validated + * @group param + */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") - def getEstimator: Estimator[_] = get(estimator) - /** param for estimator param maps */ + /** @group getParam */ + def getEstimator: Estimator[_] = getOrDefault(estimator) + + /** + * param for estimator param maps + * @group param + */ val estimatorParamMaps: Param[Array[ParamMap]] = new Param(this, "estimatorParamMaps", "param maps for the estimator") - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) - /** param for the evaluator for selection */ + /** @group getParam */ + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) + + /** + * param for the evaluator for selection + * @group param + */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") - def getEvaluator: Evaluator = get(evaluator) - /** param for number of folds for cross validation */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) - def getNumFolds: Int = get(numFolds) + /** @group getParam */ + def getEvaluator: Evaluator = getOrDefault(evaluator) + + /** + * param for number of folds for cross validation + * @group param + */ + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") + + /** @group getParam */ + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -59,13 +81,20 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP private val f2jBLAS = new F2jBLAS + /** @group setParam */ def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -74,13 +103,14 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => - val trainingDataset = sqlCtx.applySchema(training, schema).cache() - val validationDataset = sqlCtx.applySchema(validation, schema).cache() + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -88,6 +118,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") @@ -100,8 +131,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP cvModel } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } @@ -117,11 +148,11 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { bestModel.transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala new file mode 100644 index 000000000000..c84c8b4eb744 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.immutable.HashMap + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, + NumericAttribute} +import org.apache.spark.sql.types.StructField + + +/** + * :: Experimental :: + * + * Helper utilities for tree-based algorithms + */ +@Experimental +object MetadataUtils { + + /** + * Examine a schema to identify the number of classes in a label column. + * Returns None if the number of labels is not specified, or if the label column is continuous. + */ + def getNumClasses(labelSchema: StructField): Option[Int] = { + Attribute.fromStructField(labelSchema) match { + case numAttr: NumericAttribute => None + case binAttr: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + } + } + + /** + * Examine a schema to identify categorical (Binary and Nominal) features. + * + * @param featuresSchema Schema of the features column. + * If a feature does not have metadata, it is assumed to be continuous. + * If a feature is Nominal, then it must have the number of values + * specified. + * @return Map: feature index --> number of categories. + * The map's set of keys will be the set of categorical feature indices. + */ + def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = { + val metadata = AttributeGroup.fromStructField(featuresSchema) + if (metadata.attributes.isEmpty) { + HashMap.empty[Int, Int] + } else { + metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) => + if (attr == null) { + Iterator() + } else { + attr match { + case numAttr: NumericAttribute => Iterator() + case binAttr: BinaryAttribute => Iterator(idx -> 2) + case nomAttr: NominalAttribute => + nomAttr.getNumValues match { + case Some(numValues: Int) => Iterator(idx -> numValues) + case None => throw new IllegalArgumentException(s"Feature $idx is marked as" + + " Nominal (categorical), but it does not have the number of values specified.") + } + } + } + }.toMap + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 000000000000..0383bf0b382b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala new file mode 100644 index 000000000000..ee933f4cfcaf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of FPGrowthModel to provide helper method for Python + */ +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { + + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala new file mode 100644 index 000000000000..534edac56bc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of MatrixFactorizationModel to provide helper method for Python. + */ +private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(userFeatures.map { + case (user, feature) => (user, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) + } + + def getProductFeatures: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(productFeatures.map { + case (product, feature) => (product, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 430d763ef7ca..6237b64c8f98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -22,41 +22,41 @@ import java.nio.{ByteBuffer, ByteOrder} import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import scala.reflect.ClassTag import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** - * :: DeveloperApi :: - * The Java stubs necessary for the Python mllib bindings. + * The Java stubs necessary for the Python mllib bindings. It is called by Py4J on the Python side. */ -@DeveloperApi -class PythonMLLibAPI extends Serializable { - +private[python] class PythonMLLibAPI extends Serializable { /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -77,7 +77,13 @@ class PythonMLLibAPI extends Serializable { initialWeights: Vector): JList[Object] = { try { val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) - List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + if (model.isInstanceOf[LogisticRegressionModel]) { + val lrModel = model.asInstanceOf[LogisticRegressionModel] + List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses) + .map(_.asInstanceOf[Object]).asJava + } else { + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } } finally { data.rdd.unpersist(blocking = false) } @@ -110,9 +116,11 @@ class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) + .setValidateData(validateData) lrAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -134,8 +142,12 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lassoAlg = new LassoWithSGD() + lassoAlg.setIntercept(intercept) + .setValidateData(validateData) lassoAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -156,8 +168,12 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.setIntercept(intercept) + .setValidateData(validateData) ridgeAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -180,9 +196,11 @@ class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) + .setValidateData(validateData) SVMAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -206,9 +224,11 @@ class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -232,9 +252,13 @@ class PythonMLLibAPI extends Serializable { regType: String, intercept: Boolean, corrections: Int, - tolerance: Double): JList[Object] = { + tolerance: Double, + validateData: Boolean, + numClasses: Int): JList[Object] = { val LogRegAlg = new LogisticRegressionWithLBFGS() LogRegAlg.setIntercept(intercept) + .setValidateData(validateData) + .setNumClasses(numClasses) LogRegAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -254,12 +278,12 @@ class PythonMLLibAPI extends Serializable { data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)). map(_.asInstanceOf[Object]).asJava } /** - * Java stub for Python mllib KMeans.train() + * Java stub for Python mllib KMeans.run() */ def trainKMeansModel( data: JavaRDD[Vector], @@ -284,18 +308,55 @@ class PythonMLLibAPI extends Serializable { } /** - * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + * Java stub for Python mllib GaussianMixture.run() + * Returns a list containing weights, mean and covariance of each mixture component. */ - private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) - extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { - - def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = - predict(SerDe.asTupleRDD(userAndProducts.rdd)) + def trainGaussianMixture( + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, + maxIterations: Int, + seed: java.lang.Long): JList[Object] = { + val gmmAlg = new GaussianMixture() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) - def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + if (seed != null) gmmAlg.setSeed(seed) - def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + try { + val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + var wt = ArrayBuffer.empty[Double] + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + for (i <- 0 until model.k) { + wt += model.weights(i) + mu += model.gaussians(i).mu + sigma += model.gaussians(i).sigma + } + List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + /** + * Java stub for Python mllib GaussianMixtureModel.predictSoft() + */ + def predictSoftGMM( + data: JavaRDD[Vector], + wt: Vector, + mu: Array[Object], + si: Array[Object]): RDD[Vector] = { + + val weight = wt.toArray + val mean = mu.map(_.asInstanceOf[DenseVector]) + val sigma = si.map(_.asInstanceOf[DenseMatrix]) + val gaussians = Array.tabulate(weight.length){ + i => new MultivariateGaussian(mean(i), sigma(i)) + } + val model = new GaussianMixtureModel(weight, gaussians) + model.predictSoft(data).map(Vectors.dense) } /** @@ -357,6 +418,24 @@ class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -370,9 +449,9 @@ class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. @@ -413,13 +492,15 @@ class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -453,6 +534,10 @@ class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } } /** @@ -532,6 +617,35 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GradientBoostedTrees.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainGradientBoostedTreesModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + categoricalFeaturesInfo: JMap[Int, Int], + lossStr: String, + numIterations: Int, + learningRate: Double, + maxDepth: Int): GradientBoostedTreesModel = { + val boostingStrategy = BoostingStrategy.defaultParams(algoStr) + boostingStrategy.setLoss(Losses.fromString(lossStr)) + boostingStrategy.setNumIterations(numIterations) + boostingStrategy.setLearningRate(learningRate) + boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + GradientBoostedTrees.train(cached, boostingStrategy) + } finally { + cached.unpersist(blocking = false) + } + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. @@ -821,13 +935,21 @@ private[spark] object SerDe extends Serializable { out.write(code) } + protected def getBytes(obj: Object): Array[Byte] = { + if (obj.getClass.isArray) { + obj.asInstanceOf[Array[Byte]] + } else { + obj.asInstanceOf[String].getBytes(LATIN1) + } + } + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } // Pickler for DenseVector private[python] class DenseVectorPickler extends BasePickler[DenseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val vector: DenseVector = obj.asInstanceOf[DenseVector] val bytes = new Array[Byte](8 * vector.size) val bb = ByteBuffer.wrap(bytes) @@ -846,7 +968,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(0)) val bb = ByteBuffer.wrap(bytes, 0, bytes.length) bb.order(ByteOrder.nativeOrder()) val db = bb.asDoubleBuffer() @@ -859,12 +981,14 @@ private[spark] object SerDe extends Serializable { // Pickler for DenseMatrix private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -872,26 +996,29 @@ private[spark] object SerDe extends Serializable { out.write(Opcodes.BINSTRING) out.write(PickleUtils.integer_to_bytes(bytes.length)) out.write(bytes) - out.write(Opcodes.TUPLE3) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) } def construct(args: Array[Object]): Object = { - if (args.length != 3) { - throw new PickleException("should be 3") + if (args.length != 4) { + throw new PickleException("should be 4") } - val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) } } // Pickler for SparseVector private[python] class SparseVectorPickler extends BasePickler[SparseVector] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val v: SparseVector = obj.asInstanceOf[SparseVector] val n = v.indices.size val indiceBytes = new Array[Byte](4 * n) @@ -916,8 +1043,8 @@ private[spark] object SerDe extends Serializable { throw new PickleException("should be 3") } val size = args(0).asInstanceOf[Int] - val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) - val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) val n = indiceBytes.length / 4 val indices = new Array[Int](n) val values = new Array[Double](n) @@ -933,7 +1060,7 @@ private[spark] object SerDe extends Serializable { // Pickler for LabeledPoint private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val point: LabeledPoint = obj.asInstanceOf[LabeledPoint] saveObjects(out, pickler, point.label, point.features) } @@ -949,7 +1076,7 @@ private[spark] object SerDe extends Serializable { // Pickler for Rating private[python] class RatingPickler extends BasePickler[Rating] { - def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val rating: Rating = obj.asInstanceOf[Rating] saveObjects(out, pickler, rating.user, rating.product, rating.rating) } @@ -1021,7 +1148,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d7..35a0db76f3a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.classification +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector @@ -53,3 +55,15 @@ trait ClassificationModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. + * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: JValue): (Int, Int) = { + implicit val formats = DefaultFormats + ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 94d757bc317a..057b628c6a58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,31 +17,73 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.classification.impl.GLMClassificationModel +import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD + /** - * Classification model trained using Logistic Regression. + * Classification model trained using Multinomial/Binary Logistic Regression. * * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. + * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. + * In Multinomial Logistic Regression, the intercepts will not be a single value, + * so the intercepts will be part of the weights.) + * @param numFeatures the dimension of the features. + * @param numClasses the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. By default, it is binary logistic regression + * so numClasses will be set to 2. */ class LogisticRegressionModel ( override val weights: Vector, - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + override val intercept: Double, + val numFeatures: Int, + val numClasses: Int) + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { + + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } + + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + + /** + * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. + */ + def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) /** * :: Experimental :: - * Sets the threshold that separates positive predictions from negative predictions. An example - * with prediction score greater than or equal to this threshold is identified as an positive, - * and negative otherwise. The default value is 0.5. + * Sets the threshold that separates positive predictions from negative predictions + * in Binary Logistic Regression. An example with prediction score greater than or equal to + * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -49,9 +91,18 @@ class LogisticRegressionModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + * It is only used for binary classification. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ @Experimental def clearThreshold(): this.type = { @@ -59,25 +110,97 @@ class LogisticRegressionModel ( this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { - val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0 / (1.0 + math.exp(-margin)) - threshold match { - case Some(t) => if (score > t) 1.0 else 0.0 - case None => score + require(dataMatrix.size == numFeatures) + + // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. + if (numClasses == 2) { + val margin = dot(weightMatrix, dataMatrix) + intercept + val score = 1.0 / (1.0 + math.exp(-margin)) + threshold match { + case Some(t) => if (score > t) 1.0 else 0.0 + case None => score + } + } else { + /** + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. + * + * PS, if you want to compute the probabilities for each outcome instead of the outcome + * with maximum probability, remember to subtract the maxMargin from margins if maxMargin + * is positive to prevent overflow. + */ + var bestClass = 0 + var maxMargin = 0.0 + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. + if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin + bestClass = i + 1 + } + } + bestClass.toDouble + } + } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures, numClasses, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}" + } +} + +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") } } } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By - * default L2 regularization is used, which can be changed via - * [[LogisticRegressionWithSGD.optimizer]]. - * NOTE: Labels used in Logistic Regression should be {0, 1}. + * Train a classification model for Binary Logistic Regression + * using Stochastic Gradient Descent. By default L2 regularization is used, + * which can be changed via [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ -class LogisticRegressionWithSGD private ( +class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var regParam: Double, @@ -99,7 +222,7 @@ class LogisticRegressionWithSGD private ( */ def this() = this(1.0, 100, 0.01, 1.0) - override protected def createModel(weights: Vector, intercept: Double) = { + override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } @@ -194,9 +317,10 @@ object LogisticRegressionWithSGD { } /** - * Train a classification model for Logistic Regression using Limited-memory BFGS. - * Standard feature scaling and L2 regularization are used by default. - * NOTE: Labels used in Logistic Regression should be {0, 1} + * Train a classification model for Multinomial/Binary Logistic Regression using + * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. + * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { @@ -205,9 +329,37 @@ class LogisticRegressionWithLBFGS override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) - override protected val validators = List(DataValidators.binaryLabelValidator) + override protected val validators = List(multiLabelValidator) + + private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data => + if (numOfLinearPredictor > 1) { + DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data) + } else { + DataValidators.binaryLabelValidator(data) + } + } + + /** + * :: Experimental :: + * Set the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. + * By default, it is binary logistic regression so k will be set to 2. + */ + @Experimental + def setNumClasses(numClasses: Int): this.type = { + require(numClasses > 1) + numOfLinearPredictor = numClasses - 1 + if (numClasses > 2) { + optimizer.setGradient(new LogisticGradient(numClasses)) + } + this + } override protected def createModel(weights: Vector, intercept: Double) = { - new LogisticRegressionModel(weights, intercept) + if (numOfLinearPredictor == 1) { + new LogisticRegressionModel(weights, intercept) + } else { + new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a967df857bed..c9b3ff0172e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,13 +17,24 @@ package org.apache.spark.mllib.classification -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import java.lang.{Iterable => JIterable} -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ +import scala.collection.JavaConverters._ + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.numerics.{exp => brzExp, log => brzLog} + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JValue} + +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + /** * Model for Naive Bayes Classifiers. @@ -32,26 +43,39 @@ import org.apache.spark.rdd.RDD * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features + * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + val theta: Array[Array[Double]], + val modelType: String) + extends ClassificationModel with Serializable with Saveable { + + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = + this(labels, pi, theta, "Multinomial") + + /** A Java-friendly constructor that takes three Iterable parameters. */ + private[mllib] def this( + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = + this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM[Double](theta.length, theta(0).length) + private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t - { - // Need to put an extra pair of braces to prevent Scala treating `i` as a member. - var i = 0 - while (i < theta.length) { - var j = 0 - while (j < theta(i).length) { - brzTheta(i, j) = theta(i)(j) - j += 1 - } - i += 1 - } + // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // application of this condition (in predict function). + private val (brzNegTheta, brzNegThetaSum) = modelType match { + case "Multinomial" => (None, None) + case "Bernoulli" => + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) + (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -63,7 +87,150 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) + modelType match { + case "Multinomial" => + labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + case "Bernoulli" => + labels (brzArgmax (brzPi + + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + } + } + + override def save(sc: SparkContext, path: String): Unit = { + val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) + NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) + } + + override protected def formatVersion: String = "2.0" +} + +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + import org.apache.spark.mllib.util.Loader._ + + private[mllib] object SaveLoadV2_0 { + + def thisFormatVersion: String = "2.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + val modelType = data.getString(3) + new NaiveBayesModel(labels, pi, theta, modelType) + } + + } + + private[mllib] object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel" + + /** Model data for model import/export */ + case class Data( + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) + + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + dataRDD.saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(dataPath(path)) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) + } + } + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + val (model, numFeatures, numClasses) = (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV1_0.load(sc, path) + (model, numFeatures, numClasses) + case (className, "2.0") if className == classNameV2_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val model = SaveLoadV2_0.load(sc, path) + (model, numFeatures, numClasses) + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model } } @@ -75,9 +242,14 @@ class NaiveBayesModel private[mllib] ( * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ -class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { - def this() = this(1.0) +class NaiveBayes private ( + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { + + def this(lambda: Double) = this(lambda, "Multinomial") + + def this() = this(1.0, "Multinomial") /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -85,12 +257,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } + /** Get the smoothing parameter. */ + def getLambda: Double = lambda + + /** + * Set the model type using a string (case-sensitive). + * Supported options: "Multinomial" and "Bernoulli". + * (default: Multinomial) + */ + def setModelType(modelType:String): NaiveBayes = { + require(NaiveBayes.supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + this.modelType = modelType + this + } + + /** Get the model type. */ + def getModelType: String = this.modelType + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ - def run(data: RDD[LabeledPoint]) = { + def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { case SparseVector(size, indices, values) => @@ -118,21 +308,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() + val numLabels = aggregated.length var numDocuments = 0L aggregated.foreach { case (_, (n, _)) => numDocuments += n } val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) val pi = new Array[Double](numLabels) val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => labels(i) = label - val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) pi(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = modelType match { + case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Bernoulli" => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + } var j = 0 while (j < numFeatures) { theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom @@ -141,7 +340,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with i += 1 } - new NaiveBayesModel(labels, pi, theta) + new NaiveBayesModel(labels, pi, theta, modelType) } } @@ -149,13 +348,16 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * Top-level methods for calling naive Bayes. */ object NaiveBayes { + + /* Set of modelTypes that NaiveBayes supports */ + private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * This version of the method uses a default smoothing parameter of 1.0. * @@ -169,16 +371,40 @@ object NaiveBayes { /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. * - * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of - * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for - * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all + * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it + * can be used for document classification. * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda).run(input) + new NaiveBayes(lambda, "Multinomial").run(input) + } + + /** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]]) + * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle + * discrete count data and can be called by setting the model type to "multinomial". + * For example, it can be used with word counts or TF_IDF vectors of documents. + * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a + * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as + * Bernoulli NB. + * + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency + * vector or a count vector. + * @param lambda The smoothing parameter + * + * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be + * multinomial or bernoulli + */ + def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { + require(supportedModelTypes.contains(modelType), + s"NaiveBayes was created with an unknown ModelType: $modelType") + new NaiveBayes(lambda, modelType).run(input) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a37f..52fb62dcff1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,11 +17,13 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD /** @@ -33,7 +35,8 @@ import org.apache.spark.rdd.RDD class SVMModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Saveable { private var threshold: Option[Double] = Some(0.0) @@ -49,6 +52,13 @@ class SVMModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -69,6 +79,45 @@ class SVMModel ( case None => margin } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) + } + + override protected def formatVersion: String = "1.0" + + override def toString: String = { + s"${super.toString}, numClasses = 2, threshold = ${threshold.get}" + } +} + +object SVMModel extends Loader[SVMModel] { + + override def load(sc: SparkContext, path: String): SVMModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala new file mode 100644 index 000000000000..7d33df3221fb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.StreamingLinearAlgorithm + +/** + * :: Experimental :: + * Train or predict a logistic regression model on streaming data. Training uses + * Stochastic Gradient Descent to update the model based on each new batch of + * incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation) + * + * Each batch of data is assumed to be an RDD of LabeledPoints. + * The number of data points per batch can vary, but the number + * of features must be constant. An initial weight + * vector must be provided. + * + * Use a builder pattern to construct a streaming logistic regression + * analysis in an application, like: + * + * {{{ + * val model = new StreamingLogisticRegressionWithSGD() + * .setStepSize(0.5) + * .setNumIterations(10) + * .setInitialWeights(Vectors.dense(...)) + * .trainOn(DStream) + * }}} + */ +@Experimental +class StreamingLogisticRegressionWithSGD private[mllib] ( + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double, + private var regParam: Double) + extends StreamingLinearAlgorithm[LogisticRegressionModel, LogisticRegressionWithSGD] + with Serializable { + + /** + * Construct a StreamingLogisticRegression object with default parameters: + * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0, regParam: 0.0}. + * Initial weights must be set before using trainOn or predictOn + * (see `StreamingLinearAlgorithm`) + */ + def this() = this(0.1, 50, 1.0, 0.0) + + protected val algorithm = new LogisticRegressionWithSGD( + stepSize, numIterations, regParam, miniBatchFraction) + + protected var model: Option[LogisticRegressionModel] = None + + /** Set the step size for gradient descent. Default: 0.1. */ + def setStepSize(stepSize: Double): this.type = { + this.algorithm.optimizer.setStepSize(stepSize) + this + } + + /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + def setNumIterations(numIterations: Int): this.type = { + this.algorithm.optimizer.setNumIterations(numIterations) + this + } + + /** Set the fraction of each batch to use for updates. Default: 1.0. */ + def setMiniBatchFraction(miniBatchFraction: Double): this.type = { + this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) + this + } + + /** Set the regularization parameter. Default: 0.0. */ + def setRegParam(regParam: Double): this.type = { + this.algorithm.optimizer.setRegParam(regParam) + this + } + + /** Set the initial weights. Default: [0.0, 0.0]. */ + def setInitialWeights(initialWeights: Vector): this.type = { + this.model = Some(algorithm.createModel(initialWeights, 0.0)) + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 000000000000..3b6790cce47c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{Row, SQLContext} + +/** + * Helper class for import/export of GLM classification models. + */ +private[classification] object GLMClassificationModel { + + object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Model data for import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + numFeatures: Int, + numClasses: Int, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. + * + * @param modelClass String name for model class (used for error messages) + */ + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala similarity index 85% rename from mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala rename to mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 899fe5e9e9cf..568b65305649 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq -import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** + * :: Experimental :: + * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights @@ -38,19 +41,27 @@ import org.apache.spark.util.Utils * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. - * + * + * Note: For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * * @param k The number of independent Gaussians in the mixture model * @param convergenceTol The maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ -class GaussianMixtureEM private ( +@Experimental +class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + /** + * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, + * maxIterations: 100, seed: random}. + */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -123,7 +134,7 @@ class GaussianMixtureEM private ( val sc = data.sparkContext // we will operate on the data as breeze data - val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + val breezeData = data.map(_.toBreeze).cache() // Get length of the input vectors val d = breezeData.first().length @@ -141,7 +152,7 @@ class GaussianMixtureEM private ( (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) - }) + }) } } @@ -162,7 +173,7 @@ class GaussianMixtureEM private ( var i = 0 while (i < k) { val mu = sums.means(i) / sums.weights(i) - BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector], + BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) weights(i) = sums.weights(i) / sumWeights gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) @@ -178,8 +189,8 @@ class GaussianMixtureEM private ( } /** Average of dense breeze vectors */ - private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { - val v = BreezeVector.zeros[Double](x(0).length) + private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { + val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) v / x.length.toDouble } @@ -188,10 +199,10 @@ class GaussianMixtureEM private ( * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) */ - private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) - val ss = BreezeVector.zeros[Double](x(0).length) - x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + val ss = BDV.zeros[Double](x(0).length) + x.foreach(xi => ss += (xi - mu) :^ 2.0) diag(ss / x.length.toDouble) } } @@ -200,7 +211,7 @@ class GaussianMixtureEM private ( private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) } // compute cluster contributions for each input point @@ -208,19 +219,18 @@ private object ExpectationSum { def add( weights: Array[Double], dists: Array[MultivariateGaussian]) - (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) } val pSum = p.sum sums.logLikelihood += math.log(pSum) - val xxt = x * new Transpose(x) var i = 0 while (i < sums.k) { p(i) /= pSum sums.weights(i) += p(i) sums.means(i) += x * p(i) - BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector], + BLAS.syr(p(i), Vectors.fromBreeze(x), Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix]) i = i + 1 } @@ -232,7 +242,7 @@ private object ExpectationSum { private class ExpectationSum( var logLikelihood: Double, val weights: Array[Double], - val means: Array[BreezeVector[Double]], + val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { val k = weights.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 1a2178ee7f71..ec65a3da689d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,12 +19,21 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} /** + * :: Experimental :: + * * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. @@ -35,12 +44,19 @@ import org.apache.spark.mllib.util.MLUtils * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the * covariance matrix for Gaussian i */ +@Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable { + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") - + + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) + } + /** Number of gaussians in mixture */ def k: Int = weights.length @@ -79,5 +95,79 @@ class GaussianMixtureModel( p(i) /= pSum } p - } + } +} + +@Experimental +object GaussianMixtureModel extends Loader[GaussianMixtureModel] { + + private object SaveLoadV1_0 { + + case class Data(weight: Double, mu: Vector, sigma: Matrix) + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel" + + def save( + sc: SparkContext, + path: String, + weights: Array[Double], + gaussians: Array[MultivariateGaussian]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(weights.length) { i => + Data(weights(i), gaussians(i).mu, gaussians(i).sigma) + } + sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): GaussianMixtureModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val (weights, gaussians) = dataArray.map { + case Row(weight: Double, mu: Vector, sigma: Matrix) => + (weight, new MultivariateGaussian(mu, sigma)) + }.unzip + + return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + } + } + + override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val k = (metadata \ "k").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, version) match { + case (classNameV1_0, "1.0") => { + val model = SaveLoadV1_0.load(sc, path) + require(model.weights.length == k, + s"GaussianMixtureModel requires weights of length $k " + + s"got weights of length ${model.weights.length}") + require(model.gaussians.length == k, + s"GaussianMixtureModel requires gaussians of length $k" + + s"got gaussians of length ${model.gaussians.length}") + model + } + case _ => throw new Exception( + s"GaussianMixtureModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 11633e824231..0f8d6a399682 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -52,18 +52,33 @@ class KMeans private ( */ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) + /** + * Number of clusters to create (k). + */ + def getK: Int = k + /** Set the number of clusters to create (k). Default: 2. */ def setK(k: Int): this.type = { this.k = k this } + /** + * Maximum number of iterations to run. + */ + def getMaxIterations: Int = maxIterations + /** Set maximum number of iterations to run. Default: 20. */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } + /** + * The initialization algorithm. This can be either "random" or "k-means||". + */ + def getInitializationMode: String = initializationMode + /** * Set the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ @@ -77,6 +92,13 @@ class KMeans private ( this } + /** + * :: Experimental :: + * Number of runs of the algorithm to execute in parallel. + */ + @Experimental + def getRuns: Int = runs + /** * :: Experimental :: * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm @@ -92,6 +114,11 @@ class KMeans private ( this } + /** + * Number of steps for the k-means|| initialization mode + */ + def getInitializationSteps: Int = initializationSteps + /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. @@ -104,6 +131,11 @@ class KMeans private ( this } + /** + * The distance threshold within which we've consider centers to have converged. + */ + def getEpsilon: Double = epsilon + /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. @@ -113,6 +145,11 @@ class KMeans private ( this } + /** + * The random seed for cluster initialization. + */ + def getSeed: Long = seed + /** Set the random seed for cluster initialization. */ def setSeed(seed: Long): this.type = { this.seed = seed @@ -499,5 +536,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable def this(array: Array[Double]) = this(Vectors.dense(array)) /** Converts the vector to a dense vector. */ - def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm) + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3b95a9e6936e..e4e411a3c8b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,15 +17,27 @@ package org.apache.spark.mllib.clustering +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { +class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable { + + /** A Java-friendly constructor that takes an Iterable of Vectors. */ + def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) /** Total number of clusters. */ def k: Int = clusterCenters.length @@ -58,4 +70,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable { private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + + override def save(sc: SparkContext, path: String): Unit = { + KMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object KMeansModel extends Loader[KMeansModel] { + override def load(sc: SparkContext, path: String): KMeansModel = { + KMeansModel.SaveLoadV1_0.load(sc, path) + } + + private case class Cluster(id: Int, point: Vector) + + private object Cluster { + def apply(r: Row): Cluster = { + Cluster(r.getInt(0), r.getAs[Vector](1)) + } + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => + Cluster(id, point) + }.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centriods = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centriods.schema) + val localCentriods = centriods.map(Cluster.apply).collect() + assert(k == localCentriods.size) + new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 000000000000..d006b39acb21 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "word" = "term": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over words representing some concept + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Experimental +class LDA private ( + private var k: Int, + private var maxIterations: Int, + private var docConcentration: Double, + private var topicConcentration: Double, + private var seed: Long, + private var checkpointInterval: Int) extends Logging { + + def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, + seed = Utils.random.nextLong(), checkpointInterval = 10) + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + */ + def getK: Int = k + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + * (default = 10) + */ + def setK(k: Int): this.type = { + require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") + this.k = k + this + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + */ + def getDocConcentration: Double = { + if (this.docConcentration == -1) { + (50.0 / k) + 1.0 + } else { + this.docConcentration + } + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * This value should be > 1.0, where larger values mean more smoothing (more regularization). + * If set to -1, then docConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = (50 / k) + 1. + * - The 50/k is common in LDA libraries. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setDocConcentration(docConcentration: Double): this.type = { + require(docConcentration > 1.0 || docConcentration == -1.0, + s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") + this.docConcentration = docConcentration + this + } + + /** Alias for [[getDocConcentration]] */ + def getAlpha: Double = getDocConcentration + + /** Alias for [[setDocConcentration()]] */ + def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def getTopicConcentration: Double = { + if (this.topicConcentration == -1) { + 1.1 + } else { + this.topicConcentration + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * This value should be > 0.0. + * If set to -1, then topicConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = 0.1 + 1. + * - The 0.1 gives a small amount of smoothing. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setTopicConcentration(topicConcentration: Double): this.type = { + require(topicConcentration > 1.0 || topicConcentration == -1.0, + s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") + this.topicConcentration = topicConcentration + this + } + + /** Alias for [[getTopicConcentration]] */ + def getBeta: Double = getTopicConcentration + + /** Alias for [[setTopicConcentration()]] */ + def setBeta(beta: Double): this.type = setTopicConcentration(beta) + + /** + * Maximum number of iterations for learning. + */ + def getMaxIterations: Int = maxIterations + + /** + * Maximum number of iterations for learning. + * (default = 20) + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Random seed */ + def getSeed: Long = seed + + /** Random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Period (in iterations) between checkpoints. + */ + def getCheckpointInterval: Int = checkpointInterval + + /** + * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be + * important when LDA is run for many iterations. If the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. + * + * @see [[org.apache.spark.SparkContext#setCheckpointDir]] + */ + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + + /** + * Learn an LDA model using the given dataset. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * Document IDs must be unique and >= 0. + * @return Inferred LDA model + */ + def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { + val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, + checkpointInterval) + var iter = 0 + val iterationTimes = Array.fill[Double](maxIterations)(0) + while (iter < maxIterations) { + val start = System.nanoTime() + state.next() + val elapsedSeconds = (System.nanoTime() - start) / 1e9 + iterationTimes(iter) = elapsedSeconds + iter += 1 + } + state.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(state, iterationTimes) + } + + /** Java-friendly version of [[run()]] */ + def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } +} + + +private[clustering] object LDA { + + /* + DEVELOPERS NOTE: + + This implementation uses GraphX, where the graph is bipartite with 2 types of vertices: + - Document vertices + - indexed with unique indices >= 0 + - Store vectors of length k (# topics). + - Term vertices + - indexed {-1, -2, ..., -vocabSize} + - Store vectors of length k (# topics). + - Edges correspond to terms appearing in documents. + - Edges are directed Document -> Term. + - Edges are partitioned by documents. + + Info on EM implementation. + - We follow Section 2.2 from Asuncion et al., 2009. We use some of their notation. + - In this implementation, there is one edge for every unique term appearing in a document, + i.e., for every unique (document, term) pair. + - Notation: + - N_{wkj} = count of tokens of term w currently assigned to topic k in document j + - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s) + - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j), + the probability of term x_i in document d_i having topic z_i. + - Data graph + - Document vertices store N_{kj} + - Term vertices store N_{wk} + - Edges store N_{wj}. + - Global data N_k + - Algorithm + - Initial state: + - Document and term vertices store random counts N_{wk}, N_{kj}. + - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i). + - Aggregate N_k from term vertices. + - Compute gamma_{wjk} for each possible topic k, from each triplet. + using inputs N_{wk}, N_{kj}, N_k. + - M-step: Compute sufficient statistics for hidden parameters phi and theta + (counts N_{wk}, N_{kj}, N_k). + - Document update: + - N_{kj} <- sum_w N_{wj} gamma_{wjk} + - N_j <- sum_k N_{kj} (only needed to output predictions) + - Term update: + - N_{wk} <- sum_j N_{wj} gamma_{wjk} + - N_k <- sum_w N_{wk} + + TODO: Add simplex constraints to allow alpha in (0,1). + See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling : Additive + Regularization for Stochastic Matrix Factorization." 2014. + */ + + /** + * Vector over topics (length k) of token counts. + * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. + */ + private[clustering] type TopicCounts = BDV[Double] + + private[clustering] type TokenCount = Double + + /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) + + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + + /** + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * @param graph EM graph, storing current parameter estimates in vertex descriptors and + * data (token counts) in edge descriptors. + * @param k Number of topics + * @param vocabSize Number of unique terms + * @param docConcentration "alpha" + * @param topicConcentration "beta" or "eta" + */ + private[clustering] class EMOptimizer( + var graph: Graph[TopicCounts, TokenCount], + val k: Int, + val vocabSize: Int, + val docConcentration: Double, + val topicConcentration: Double, + checkpointInterval: Int) { + + private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + graph, checkpointInterval) + + def next(): EMOptimizer = { + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + } + + /** + * Compute gamma_{wjk}, a distribution over topics k. + */ + private def computePTopic( + docTopicCounts: TopicCounts, + termTopicCounts: TopicCounts, + totalTopicCounts: TopicCounts, + vocabSize: Int, + eta: Double, + alpha: Double): TopicCounts = { + val K = docTopicCounts.length + val N_j = docTopicCounts.data + val N_w = termTopicCounts.data + val N = totalTopicCounts.data + val eta1 = eta - 1.0 + val alpha1 = alpha - 1.0 + val Weta1 = vocabSize * eta1 + var sum = 0.0 + val gamma_wj = new Array[Double](K) + var k = 0 + while (k < K) { + val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) + gamma_wj(k) = gamma_wjk + sum += gamma_wjk + k += 1 + } + // normalize + BDV(gamma_wj) /= sum + } + + /** + * Compute bipartite term/doc graph. + */ + private def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): EMOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + val graph = Graph(docTermVertices, edges) + .partitionBy(PartitionStrategy.EdgePartition1D) + + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 000000000000..0a3f21ecee0d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA) model. + * + * This abstraction permits for different underlying representations, + * including local and distributed data structures. + */ +@Experimental +abstract class LDAModel private[clustering] { + + /** Number of topics */ + def k: Int + + /** Vocabulary size (number of terms or terms in the vocabulary) */ + def vocabSize: Int + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + */ + def topicsMatrix: Matrix + + /** + * Return the topics described by weighted terms. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] + + /** + * Return the topics described by weighted terms. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = + // describeTopicsAsStrings(vocabSize) + + /* TODO + * Compute the log likelihood of the observed tokens, given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior. + * - Even with the prior, this is NOT the same as the data log likelihood given the + * hyperparameters. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated log likelihood of the data under this model + */ + // def logLikelihood(documents: RDD[(Long, Vector)]): Double + + /* TODO + * Compute the estimated topic distribution for each document. + * This is often called 'theta' in the literature. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated topic distribution for each document. + * The returned RDD may be zipped with the given RDD, where each returned vector + * is a multinomial distribution over topics. + */ + // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] + +} + +/** + * :: Experimental :: + * + * Local LDA model. + * This model stores only the inferred topics. + * It may be used for computing topics for new documents, but it may give less accurate answers + * than the [[DistributedLDAModel]]. + * + * @param topics Inferred topics (vocabSize x k matrix). + */ +@Experimental +class LocalLDAModel private[clustering] ( + private val topics: Matrix) extends LDAModel with Serializable { + + override def k: Int = topics.numCols + + override def vocabSize: Int = topics.numRows + + override def topicsMatrix: Matrix = topics + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val brzTopics = topics.toBreeze.toDenseMatrix + Range(0, k).map { topicIndex => + val topic = normalize(brzTopics(::, topicIndex), 1.0) + val (termWeights, terms) = + topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip + (terms.toArray, termWeights.toArray) + }.toArray + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} + +/** + * :: Experimental :: + * + * Distributed LDA model. + * This model stores the inferred topics, the full training dataset, and the topic distributions. + * When computing topics for new documents, it may give more accurate answers + * than the [[LocalLDAModel]]. + */ +@Experimental +class DistributedLDAModel private ( + private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private val globalTopicTotals: LDA.TopicCounts, + val k: Int, + val vocabSize: Int, + private val docConcentration: Double, + private val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double]) extends LDAModel { + + import LDA._ + + private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, + state.topicConcentration, iterationTimes) + } + + /** + * Convert model to a local model. + * The local model stores the inferred topics but not the topic distributions for training + * documents. + */ + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. + */ + override lazy val topicsMatrix: Matrix = { + // Collect row-major topics + val termTopicCounts: Array[(Int, TopicCounts)] = + graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => + (index2term(termIndex), cnts) + }.collect() + // Convert to Matrix + val brzTopics = BDM.zeros[Double](vocabSize, k) + termTopicCounts.foreach { case (term, cnts) => + var j = 0 + while (j < k) { + brzTopics(term, j) = cnts(j) + j += 1 + } + } + Matrices.fromBreeze(brzTopics) + } + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val numTopics = k + // Note: N_k is not needed to find the top terms, but it is needed to normalize weights + // to a distribution over terms. + val N_k: TopicCounts = globalTopicTotals + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } + } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 + } + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + */ + lazy val logLikelihood: Double = { + val eta = topicConcentration + val alpha = docConcentration + assert(eta > 1.0) + assert(alpha > 1.0) + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + // Edges: Compute token log probability from phi_{wk}, theta_{kj}. + val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { + val N_wj = edgeContext.attr + val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) + val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) + edgeContext.sendToDst(tokenLogLikelihood) + } + graph.aggregateMessages[Double](sendMsg, _ + _) + .map(_._2).fold(0.0)(_ + _) + } + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | alpha, eta) + */ + lazy val logPrior: Double = { + val eta = topicConcentration + val alpha = docConcentration + // Term vertices: Compute phi_{wk}. Use to compute prior log probability. + // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + val seqOp: (Double, (VertexId, TopicCounts)) => Double = { + case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => + if (isTermVertex(vertex)) { + val N_wk = vertex._2 + val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + (eta - 1.0) * brzSum(phi_wk.map(math.log)) + } else { + val N_kj = vertex._2 + val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + } + } + graph.vertices.aggregate(0.0)(seqOp, _ + _) + } + + /** + * For each document in the training set, return the distribution over topics for that document + * ("theta_doc"). + * + * @return RDD of (document ID, topic distribution) pairs + */ + def topicDistributions: RDD[(Long, Vector)] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) + } + } + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala new file mode 100644 index 000000000000..aa53e88d5985 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} + +/** + * :: Experimental :: + * + * Model produced by [[PowerIterationClustering]]. + * + * @param k number of clusters + * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s + */ +@Experimental +class PowerIterationClusteringModel( + val k: Int, + val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} + +/** + * :: Experimental :: + * + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very + * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise + * similarity matrix of the data. + * + * @param k Number of clusters. + * @param maxIterations Maximum number of iterations of the PIC algorithm. + * @param initMode Initialization mode. + * + * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] + */ +@Experimental +class PowerIterationClustering private[clustering] ( + private var k: Int, + private var maxIterations: Int, + private var initMode: String) extends Serializable { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. + */ + def this() = this(k = 2, maxIterations = 100, initMode = "random") + + /** + * Set the number of clusters. + */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** + * Set maximum number of iterations of the power iteration loop + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** + * Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. Default: random. + */ + def setInitializationMode(mode: String): this.type = { + this.initMode = mode match { + case "random" | "degree" => mode + case _ => throw new IllegalArgumentException("Invalid initialization mode: " + mode) + } + this + } + + /** + * Run the PIC algorithm. + * + * @param similarities an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix, which is + * the matrix A in the PIC paper. The similarity s,,ij,, must be nonnegative. + * This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For any (i, j) with + * nonzero similarity, there should be either (i, j, s,,ij,,) or + * (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { + val w = normalize(similarities) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } + pic(w0) + } + + /** + * A Java-friendly version of [[PowerIterationClustering.run]]. + */ + def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) + : PowerIterationClusteringModel = { + run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) + } + + /** + * Runs the PIC algorithm. + * + * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with + * w,,ij,, = a,,ij,, / d,,ii,, as its edge properties and the initial vector of the power + * iteration as its vertex properties. + */ + private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { + val v = powerIter(w, maxIterations) + val assignments = kMeans(v, k).mapPartitions({ iter => + iter.map { case (id, cluster) => + Assignment(id, cluster) + } + }, preservesPartitioning = true) + new PowerIterationClusteringModel(k, assignments) + } +} + +@Experimental +object PowerIterationClustering extends Logging { + + /** + * :: Experimental :: + * Cluster assignment. + * @param id node id + * @param cluster assigned cluster id + */ + @Experimental + case class Assignment(id: Long, cluster: Int) + + /** + * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). + */ + private[clustering] + def normalize(similarities: RDD[(Long, Long, Double)]) + : Graph[Double, Double] = { + val edges = similarities.flatMap { case (i, j, s) => + if (s < 0.0) { + throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val gA = Graph.fromEdges(edges, 0.0) + val vD = gA.aggregateMessages[Double]( + sendMsg = ctx => { + ctx.sendToSrc(ctx.attr) + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, gA.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + + /** + * Generates random vertex properties (v0) to start power iteration. + * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing a random vector + * with unit 1-norm + */ + private[clustering] + def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { + val r = g.vertices.mapPartitionsWithIndex( + (part, iter) => { + val random = new XORShiftRandom(part) + iter.map { case (id, _) => + (id, random.nextGaussian()) + } + }, preservesPartitioning = true).cache() + val sum = r.values.map(math.abs).sum() + val v0 = r.mapValues(x => x / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + + /** + * Generates the degree vector as the vertex properties (v0) to start power iteration. + * It is not exactly the node degrees but just the normalized sum similarities. Call it + * as degree vector because it is used in the PIC paper. + * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing the degree vector + */ + private[clustering] + def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { + val sum = g.vertices.values.sum() + val v0 = g.vertices.mapValues(_ / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + + /** + * Runs power iteration. + * @param g input graph with edges representing the normalized affinity matrix (W) and vertices + * representing the initial vector of the power iterations. + * @param maxIterations maximum number of iterations + * @return a [[VertexRDD]] representing the pseudo-eigenvector + */ + private[clustering] + def powerIter( + g: Graph[Double, Double], + maxIterations: Int): VertexRDD[Double] = { + // the default tolerance used in the PIC paper, with a lower bound 1e-8 + val tol = math.max(1e-5 / g.vertices.count(), 1e-8) + var prevDelta = Double.MaxValue + var diffDelta = Double.MaxValue + var curG = g + for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) { + val msgPrefix = s"Iteration $iter" + // multiply W by vt + val v = curG.aggregateMessages[Double]( + sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), + mergeMsg = _ + _, + TripletFields.Dst).cache() + // normalize v + val norm = v.values.map(math.abs).sum() + logInfo(s"$msgPrefix: norm(v) = $norm.") + val v1 = v.mapValues(x => x / norm) + // compare difference + val delta = curG.joinVertices(v1) { case (_, x, y) => + math.abs(x - y) + }.vertices.values.sum() + logInfo(s"$msgPrefix: delta = $delta.") + diffDelta = math.abs(delta - prevDelta) + logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") + // update v + curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + prevDelta = delta + } + curG.vertices + } + + /** + * Runs k-means clustering. + * @param v a [[VertexRDD]] representing the pseudo-eigenvector + * @param k number of clusters + * @return a [[VertexRDD]] representing the clustering assignments + */ + private[clustering] + def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { + val points = v.mapValues(x => Vectors.dense(x)).cache() + val model = new KMeans() + .setK(k) + .setRuns(5) + .setSeed(0L) + .run(points.values) + points.mapValues(p => model.predict(p)).cache() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 7752c1988fdd..812014a04171 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream @@ -29,7 +28,8 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -39,8 +39,10 @@ import org.apache.spark.util.random.XORShiftRandom * generalized to incorporate forgetfullness (i.e. decay). * The update rule (for each cluster) is: * + * {{{ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] * n_t+t = n_t * a + m_t + * }}} * * Where c_t is the previously estimated centroid for that cluster, * n_t is the number of points assigned to it thus far, x_t is the centroid @@ -61,7 +63,7 @@ import org.apache.spark.util.random.XORShiftRandom * as batches or points. * */ -@DeveloperApi +@Experimental class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { @@ -140,7 +142,8 @@ class StreamingKMeansModel( } /** - * :: DeveloperApi :: + * :: Experimental :: + * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -149,17 +152,19 @@ class StreamingKMeansModel( * Use a builder pattern to construct a streaming k-means analysis * in an application, like: * + * {{{ * val model = new StreamingKMeans() * .setDecayFactor(0.5) * .setK(3) * .setRandomCenters(5, 100.0) * .trainOn(DStream) + * }}} */ -@DeveloperApi +@Experimental class StreamingKMeans( var k: Int, var decayFactor: Double, - var timeUnit: String) extends Logging { + var timeUnit: String) extends Logging with Serializable { def this() = this(2, 1.0, StreamingKMeans.BATCHES) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index ced042e2f96c..c1d1a224817e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -22,6 +22,7 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.sql.DataFrame /** * :: Experimental :: @@ -53,6 +54,13 @@ class BinaryClassificationMetrics( */ def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) + /** + * An auxiliary constructor taking a DataFrame. + * @param scoreAndLabels a DataFrame with two double columns: score and label + */ + private[mllib] def this(scoreAndLabels: DataFrame) = + this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + /** Unpersist intermediate RDDs used in the computation. */ def unpersist() { cumulativeCounts.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index ea10bde5fa25..a8378a76d20a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -96,30 +96,30 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns precision for a given label (category) * @param label the label. */ - def precision(label: Double) = { + def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) - if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp) } /** * Returns recall for a given label (category) * @param label the label. */ - def recall(label: Double) = { + def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) - if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn) } /** * Returns f1-measure for a given label (category) * @param label the label. */ - def f1Measure(label: Double) = { + def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) - if((p + r) == 0) 0 else 2 * p * r / (p + r) + if((p + r) == 0) 0.0 else 2 * p * r / (p + r) } private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } @@ -130,7 +130,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ - lazy val microPrecision = { + lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) } @@ -139,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ - lazy val microRecall = { + lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) } @@ -148,7 +148,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ - lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala new file mode 100644 index 000000000000..c6057c7f837b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Chi Squared selector model. + * + * @param selectedFeatures list of indices to select (filter). Must be ordered asc + */ +@Experimental +class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { + + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + while (i < array.length) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + override def transform(vector: Vector): Vector = { + compress(vector, selectedFeatures) + } + + /** + * Returns a vector with features filtered. + * Preserves the order of filtered features the same as their indices are stored. + * Might be moved to Vector as .slice + * @param features vector + * @param filterIndices indices of features to filter, must be ordered asc + */ + private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + features match { + case SparseVector(size, indices, values) => + val newSize = filterIndices.length + val newValues = new ArrayBuilder.ofDouble + val newIndices = new ArrayBuilder.ofInt + var i = 0 + var j = 0 + var indicesIdx = 0 + var filterIndicesIdx = 0 + while (i < indices.length && j < filterIndices.length) { + indicesIdx = indices(i) + filterIndicesIdx = filterIndices(j) + if (indicesIdx == filterIndicesIdx) { + newIndices += j + newValues += values(i) + j += 1 + i += 1 + } else { + if (indicesIdx > filterIndicesIdx) { + j += 1 + } else { + i += 1 + } + } + } + // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size) + Vectors.sparse(newSize, newIndices.result(), newValues.result()) + case DenseVector(values) => + val values = features.toArray + Vectors.dense(filterIndices.map(i => values(i))) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } +} + +/** + * :: Experimental :: + * Creates a ChiSquared feature selector. + * @param numTopFeatures number of features that selector will select + * (ordered by statistic value descending) + */ +@Experimental +class ChiSqSelector (val numTopFeatures: Int) { + + /** + * Returns a ChiSquared feature selector. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * Apply feature discretizer before using this function. + */ + def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { + val indices = Statistics.chiSqTest(data) + .zipWithIndex.sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + .map { case (_, indices) => indices } + .sorted + new ChiSqSelectorModel(indices) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3260f27513c7..a89eea0e21be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 3c2091732f9b..6ae6917eae59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,15 +18,14 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** * :: Experimental :: - * Standardizes features by removing the mean and scaling to unit variance using column summary + * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. * * @param withMean False by default. Centers the data with mean before scaling. It will build a @@ -53,7 +52,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) - new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + new StandardScalerModel( + Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))), + summary.mean, + withStd, + withMean) } } @@ -61,28 +64,43 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * :: Experimental :: * Represents a StandardScaler model that can transform vectors. * - * @param withMean whether to center the data before scaling - * @param withStd whether to scale the data to have unit standard deviation + * @param std column standard deviation values * @param mean column mean values - * @param variance column variance values + * @param withStd whether to scale the data to have unit standard deviation + * @param withMean whether to center the data before scaling */ @Experimental -class StandardScalerModel private[mllib] ( - val withMean: Boolean, - val withStd: Boolean, +class StandardScalerModel ( + val std: Vector, val mean: Vector, - val variance: Vector) extends VectorTransformer { - - require(mean.size == variance.size) + var withStd: Boolean, + var withMean: Boolean) extends VectorTransformer { - private lazy val factor: Array[Double] = { - val f = Array.ofDim[Double](variance.size) - var i = 0 - while (i < f.size) { - f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 - i += 1 + def this(std: Vector, mean: Vector) { + this(std, mean, withStd = std != null, withMean = mean != null) + require(this.withStd || this.withMean, + "at least one of std or mean vectors must be provided") + if (this.withStd && this.withMean) { + require(mean.size == std.size, + "mean and std vectors must have equal size if both are provided") } - f + } + + def this(std: Vector) = this(std, null) + + @DeveloperApi + def setWithMean(withMean: Boolean): this.type = { + require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + this.withMean = withMean + this + } + + @DeveloperApi + def setWithStd(withStd: Boolean): this.type = { + require(!(withStd && this.std == null), + "cannot set withStd to true while std is null") + this.withStd = withStd + this } // Since `shift` will be only used in `withMean` branch, we have it as @@ -94,8 +112,8 @@ class StandardScalerModel private[mllib] ( * Applies standardization transformation on a vector. * * @param vector Vector to be standardized. - * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` - * for the column with zero variance. + * @return Standardized vector. If the std of a column is zero, it will return default `0.0` + * for the column with zero std. */ override def transform(vector: Vector): Vector = { require(mean.size == vector.size) @@ -109,11 +127,9 @@ class StandardScalerModel private[mllib] ( val values = vs.clone() val size = values.size if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. - val localFactor = factor var i = 0 while (i < size) { - values(i) = (values(i) - localShift(i)) * localFactor(i) + values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 i += 1 } } else { @@ -127,15 +143,13 @@ class StandardScalerModel private[mllib] ( case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. - val localFactor = factor vector match { case DenseVector(vs) => val values = vs.clone() val size = values.size var i = 0 while(i < size) { - values(i) *= localFactor(i) + values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) i += 1 } Vectors.dense(values) @@ -146,7 +160,7 @@ class StandardScalerModel private[mllib] ( val nnz = values.size var i = 0 while (i < nnz) { - values(i) *= localFactor(indices(i)) + values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d25a7cd5b439..98e83112f52a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,18 +21,25 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.sql.{SQLContext, Row} /** * Entry in vocabulary @@ -272,7 +279,7 @@ class Word2Vec extends Serializable with Logging { def hasNext: Boolean = iter.hasNext def next(): Array[Int] = { - var sentence = new ArrayBuffer[Int] + val sentence = ArrayBuilder.make[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { val word = bcVocabHash.value.get(iter.next()) @@ -283,13 +290,20 @@ class Word2Vec extends Serializable with Logging { case None => } } - sentence.toArray + sentence.result() } } } val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) + + if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + } + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) @@ -415,7 +429,36 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable { + model: Map[String, Array[Float]]) extends Serializable with Saveable { + + // wordList: Ordered list of words obtained from model. + private val wordList: Array[String] = model.keys.toArray + + // wordIndex: Maps each word to an index, which can retrieve the corresponding + // vector from wordVectors (see below). + private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap + + // vectorSize: Dimension of each word's vector. + private val vectorSize = model.head._2.size + private val numWords = wordIndex.size + + // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word + // mapped with index i can be retrieved by the slice + // (ind * vectorSize, ind * vectorSize + vectorSize) + // wordVecNorms: Array of length numWords, each value being the Euclidean norm + // of the wordVector. + private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { + val wordVectors = new Array[Float](vectorSize * numWords) + val wordVecNorms = new Array[Double](numWords) + var i = 0 + while (i < numWords) { + val vec = model.get(wordList(i)).get + Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) + i += 1 + } + (wordVectors, wordVecNorms) + } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -425,7 +468,13 @@ class Word2VecModel private[mllib] ( if (norm1 == 0 || norm2 == 0) return 0.0 blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 } - + + override protected def formatVersion = "1.0" + + def save(sc: SparkContext, path: String): Unit = { + Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) + } + /** * Transforms a word to its vector representation * @param word a word @@ -459,20 +508,104 @@ class Word2VecModel private[mllib] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) - model.mapValues(vec => cosineSimilarity(fVector, vec)) + val cosineVec = Array.fill[Float](numWords)(0) + val alpha: Float = 1 + val beta: Float = 0 + + blas.sgemv( + "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) + + // Need not divide with the norm of the given vector since it is constant. + val updatedCosines = new Array[Double](numWords) + var ind = 0 + while (ind < numWords) { + updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + ind += 1 + } + wordList.zip(updatedCosines) .toSeq .sortBy(- _._2) .take(num + 1) .tail .toArray } - + /** * Returns a map of words to their vector representations. */ def getVectors: Map[String, Array[Float]] = { - model + wordIndex.map { case (word, ind) => + (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) + } + } +} + +@Experimental +object Word2VecModel extends Loader[Word2VecModel] { + + private object SaveLoadV1_0 { + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel" + + case class Data(word: String, vector: Array[Float]) + + def load(sc: SparkContext, path: String): Word2VecModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + + val dataArray = dataFrame.select("word", "vector").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap + new Word2VecModel(word2VecMap) + } + + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val vectorSize = model.values.head.size + val numWords = model.size + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } + sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + } + + override def load(sc: SparkContext, path: String): Word2VecModel = { + + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedVectorSize = (metadata \ "vectorSize").extract[Int] + val expectedNumWords = (metadata \ "numWords").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, loadedVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path) + val vectorSize = model.getVectors.values.head.size + val numWords = model.getVectors.size + require(expectedVectorSize == vectorSize, + s"Word2VecModel requires each word to be mapped to a vector of size " + + s"$expectedVectorSize, got vector of size $vectorSize") + require(expectedNumWords == numWords, + s"Word2VecModel requires $expectedNumWords words, but got $numWords") + model + case _ => throw new Exception( + s"Word2VecModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala new file mode 100644 index 000000000000..efa8459d3cdb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import java.{util => ju} +import java.lang.{Iterable => JavaIterable} + +import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * :: Experimental :: + * + * Model trained by [[FPGrowth]], which holds frequent itemsets. + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @tparam Item item type + */ +@Experimental +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable + +/** + * :: Experimental :: + * + * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query + * Recommendation]]. PFP distributes computation in such a way that each worker executes an + * independent group of mining tasks. The FP-Growth algorithm is described in + * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate + * generation]]. + * + * @param minSupport the minimal support level of the frequent pattern, any pattern appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param numPartitions number of partitions used by parallel FP-growth + * + * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning + * (Wikipedia)]] + */ +@Experimental +class FPGrowth private ( + private var minSupport: Double, + private var numPartitions: Int) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same + * as the input data}. + */ + def this() = this(0.3, -1) + + /** + * Sets the minimal support level (default: `0.3`). + */ + def setMinSupport(minSupport: Double): this.type = { + this.minSupport = minSupport + this + } + + /** + * Sets the number of partitions used by parallel FP-growth (default: same as input data). + */ + def setNumPartitions(numPartitions: Int): this.type = { + this.numPartitions = numPartitions + this + } + + /** + * Computes an FP-Growth model that contains frequent itemsets. + * @param data input data set, each element contains a transaction + * @return an [[FPGrowthModel]] + */ + def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + val count = data.count() + val minCount = math.ceil(minSupport * count).toLong + val numParts = if (numPartitions > 0) numPartitions else data.partitions.length + val partitioner = new HashPartitioner(numParts) + val freqItems = genFreqItems(data, minCount, partitioner) + val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner) + new FPGrowthModel(freqItemsets) + } + + def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { + implicit val tag = fakeClassTag[Item] + run(data.rdd.map(_.asScala.toArray)) + } + + /** + * Generates frequent items by filtering the input data using minimal support level. + * @param minCount minimum count for frequent itemsets + * @param partitioner partitioner used to distribute items + * @return array of frequent pattern ordered by their frequencies + */ + private def genFreqItems[Item: ClassTag]( + data: RDD[Array[Item]], + minCount: Long, + partitioner: Partitioner): Array[Item] = { + data.flatMap { t => + val uniq = t.toSet + if (t.size != uniq.size) { + throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") + } + t + }.map(v => (v, 1L)) + .reduceByKey(partitioner, _ + _) + .filter(_._2 >= minCount) + .collect() + .sortBy(-_._2) + .map(_._1) + } + + /** + * Generate frequent itemsets by building FP-Trees, the extraction is done on each partition. + * @param data transactions + * @param minCount minimum count for frequent itemsets + * @param freqItems frequent items + * @param partitioner partitioner used to distribute transactions + * @return an RDD of (frequent itemset, count) + */ + private def genFreqItemsets[Item: ClassTag]( + data: RDD[Array[Item]], + minCount: Long, + freqItems: Array[Item], + partitioner: Partitioner): RDD[FreqItemset[Item]] = { + val itemToRank = freqItems.zipWithIndex.toMap + data.flatMap { transaction => + genCondTransactions(transaction, itemToRank, partitioner) + }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)( + (tree, transaction) => tree.add(transaction, 1L), + (tree1, tree2) => tree1.merge(tree2)) + .flatMap { case (part, tree) => + tree.extract(minCount, x => partitioner.getPartition(x) == part) + }.map { case (ranks, count) => + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) + } + } + + /** + * Generates conditional transactions. + * @param transaction a transaction + * @param itemToRank map from item to their rank + * @param partitioner partitioner used to distribute transactions + * @return a map of (target partition, conditional transaction) + */ + private def genCondTransactions[Item: ClassTag]( + transaction: Array[Item], + itemToRank: Map[Item, Int], + partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { + val output = mutable.Map.empty[Int, Array[Int]] + // Filter the basket by frequent items pattern and sort their ranks. + val filtered = transaction.flatMap(itemToRank.get) + ju.Arrays.sort(filtered) + val n = filtered.length + var i = n - 1 + while (i >= 0) { + val item = filtered(i) + val part = partitioner.getPartition(item) + if (!output.contains(part)) { + output(part) = filtered.slice(0, i + 1) + } + i -= 1 + } + output + } +} + +/** + * :: Experimental :: + */ +@Experimental +object FPGrowth { + + /** + * Frequent itemset. + * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. + * @param freq frequency + * @tparam Item item type + */ + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + + /** + * Returns items in a Java List. + */ + def javaItems: java.util.List[Item] = { + items.toList.asJava + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala new file mode 100644 index 000000000000..1d2d777c0079 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * FP-Tree data structure used in FP-Growth. + * @tparam T item type + */ +private[fpm] class FPTree[T] extends Serializable { + + import FPTree._ + + val root: Node[T] = new Node(null) + + private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty + + /** Adds a transaction with count. */ + def add(t: Iterable[T], count: Long = 1L): this.type = { + require(count > 0) + var curr = root + curr.count += count + t.foreach { item => + val summary = summaries.getOrElseUpdate(item, new Summary) + summary.count += count + val child = curr.children.getOrElseUpdate(item, { + val newNode = new Node(curr) + newNode.item = item + summary.nodes += newNode + newNode + }) + child.count += count + curr = child + } + this + } + + /** Merges another FP-Tree. */ + def merge(other: FPTree[T]): this.type = { + other.transactions.foreach { case (t, c) => + add(t, c) + } + this + } + + /** Gets a subtree with the suffix. */ + private def project(suffix: T): FPTree[T] = { + val tree = new FPTree[T] + if (summaries.contains(suffix)) { + val summary = summaries(suffix) + summary.nodes.foreach { node => + var t = List.empty[T] + var curr = node.parent + while (!curr.isRoot) { + t = curr.item :: t + curr = curr.parent + } + tree.add(t, node.count) + } + } + tree + } + + /** Returns all transactions in an iterator. */ + def transactions: Iterator[(List[T], Long)] = getTransactions(root) + + /** Returns all transactions under this node. */ + private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = { + var count = node.count + node.children.iterator.flatMap { case (item, child) => + getTransactions(child).map { case (t, c) => + count -= c + (item :: t, c) + } + } ++ { + if (count > 0) { + Iterator.single((Nil, count)) + } else { + Iterator.empty + } + } + } + + /** Extracts all patterns with valid suffix and minimum count. */ + def extract( + minCount: Long, + validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = { + summaries.iterator.flatMap { case (item, summary) => + if (validateSuffix(item) && summary.count >= minCount) { + Iterator.single((item :: Nil, summary.count)) ++ + project(item).extract(minCount).map { case (t, c) => + (item :: t, c) + } + } else { + Iterator.empty + } + } + } +} + +private[fpm] object FPTree { + + /** Representing a node in an FP-Tree. */ + class Node[T](val parent: Node[T]) extends Serializable { + var item: T = _ + var count: Long = 0L + val children: mutable.Map[T, Node[T]] = mutable.Map.empty + + def isRoot: Boolean = parent == null + } + + /** Summary of a item in an FP-Tree. */ + private class Summary[T] extends Serializable { + var count: Long = 0L + val nodes: ListBuffer[Node[T]] = ListBuffer.empty + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala new file mode 100644 index 000000000000..6e5dd119dd65 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.graphx.Graph +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing Graphs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are + * responsible for materializing the graph to ensure that persisting and checkpointing actually + * occur. + * + * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. + * - Unpersist graphs from queue until there are at most 3 persisted graphs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new graph, and put in a queue of checkpointed graphs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Graphs should be + * checkpointed). + * - This class removes checkpoint files once later graphs have been checkpointed. + * However, references to the older graphs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (graph1, graph2, graph3, ...) = ... + * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * graph1.vertices.count(); graph1.edges.count() + * // persisted: graph1 + * cp.updateGraph(graph2) + * graph2.vertices.count(); graph2.edges.count() + * // persisted: graph1, graph2 + * // checkpointed: graph2 + * cp.updateGraph(graph3) + * graph3.vertices.count(); graph3.edges.count() + * // persisted: graph1, graph2, graph3 + * // checkpointed: graph2 + * cp.updateGraph(graph4) + * graph4.vertices.count(); graph4.edges.count() + * // persisted: graph2, graph3, graph4 + * // checkpointed: graph4 + * cp.updateGraph(graph5) + * graph5.vertices.count(); graph5.edges.count() + * // persisted: graph3, graph4, graph5 + * // checkpointed: graph4 + * }}} + * + * @param currentGraph Initial graph + * @param checkpointInterval Graphs will be checkpointed at this interval + * @tparam VD Vertex descriptor type + * @tparam ED Edge descriptor type + * + * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + */ +private[mllib] class PeriodicGraphCheckpointer[VD, ED]( + var currentGraph: Graph[VD, ED], + val checkpointInterval: Int) extends Logging { + + /** FIFO queue of past checkpointed RDDs */ + private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() + + /** FIFO queue of past persisted RDDs */ + private val persistedQueue = mutable.Queue[Graph[VD, ED]]() + + /** Number of times [[updateGraph()]] has been called */ + private var updateCount = 0 + + /** + * Spark Context for the Graphs given to this checkpointer. + * NOTE: This code assumes that only one SparkContext is used for the given graphs. + */ + private val sc = currentGraph.vertices.sparkContext + + updateGraph(currentGraph) + + /** + * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the graph + * has been materialized. + * + * @param newGraph New graph created from previous graphs in the lineage. + */ + def updateGraph(newGraph: Graph[VD, ED]): Unit = { + if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { + newGraph.persist() + } + persistedQueue.enqueue(newGraph) + // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: + // Users should call [[updateGraph()]] when a new graph has been created, + // before the graph has been materialized. + while (persistedQueue.size > 3) { + val graphToUnpersist = persistedQueue.dequeue() + graphToUnpersist.unpersist(blocking = false) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + newGraph.checkpoint() + checkpointQueue.enqueue(newGraph) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (checkpointQueue.get(1).get.isCheckpointed) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + old.getCheckpointFiles.foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca..87052e1ba853 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging { * @param x the vector x that contains the n elements. * @param A the symmetric matrix A. Size of n x n. */ - def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + def syr(alpha: Double, x: Vector, A: DenseMatrix) { val mA = A.numRows val nA = A.numCols - require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA") + require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") + x match { + case dv: DenseVector => syr(alpha, dv, A) + case sv: SparseVector => syr(alpha, sv, A) + case _ => + throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") + } + } + + private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { + val nA = A.numRows + val mA = A.numCols + nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA) // Fill lower triangular part of A @@ -255,82 +267,77 @@ private[spark] object BLAS extends Serializable with Logging { } } + private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { + val mA = A.numCols + val xIndices = x.indices + val xValues = x.values + val nnz = xValues.length + val Avalues = A.values + + var i = 0 + while (i < nnz) { + val multiplier = alpha * xValues(i) + val offset = xIndices(i) * mA + var j = 0 + while (j < nnz) { + Avalues(xIndices(j) + offset) += multiplier * xValues(j) + j += 1 + } + i += 1 + } + } + /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. Returning C.") } else { A match { - case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) - case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) + case dense: DenseMatrix => gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. - */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +345,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +363,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +390,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +415,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +437,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +454,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. @@ -457,65 +461,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. */ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +506,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 3515461b5249..866936aa4f11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -79,6 +79,9 @@ private[mllib] object EigenValueDecomposition { // Mode 1: A*x = lambda*x, A symmetric iparam(6) = 1 + require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, + s"k = $k and/or n = $n are too large to compute an eigendecomposition") + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) @@ -114,7 +117,7 @@ private[mllib] object EigenValueDecomposition { info.`val` match { case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " Maximum number of iterations taken. (Refer ARPACK user guide for details)") - case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + " No shifts could be applied. Try to increase NCV. " + "(Refer ARPACK user guide for details)") case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3..3fa5e068d16d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + /** * Trait for a local matrix. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) sealed trait Matrix extends Serializable { /** Number of rows. */ @@ -34,14 +40,23 @@ sealed trait Matrix extends Serializable { /** Number of columns. */ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double + def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ private[mllib] def index(i: Int, j: Int): Int @@ -52,10 +67,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,23 +84,12 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ @@ -92,6 +99,100 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. */ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) +} + +@DeveloperApi +private[spark] class MatrixUDT extends UserDefinedType[Matrix] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are + // set as not nullable, except values since in the future, support for binary matrices might + // be added for which values are not needed. + // the sparse matrix needs colPtrs and rowIndices, which are set as + // null, while building the dense matrix. + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("numRows", IntegerType, nullable = false), + StructField("numCols", IntegerType, nullable = false), + StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true), + StructField("isTransposed", BooleanType, nullable = false) + )) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(7) + obj match { + case sm: SparseMatrix => + row.setByte(0, 0) + row.setInt(1, sm.numRows) + row.setInt(2, sm.numCols) + row.update(3, sm.colPtrs.toSeq) + row.update(4, sm.rowIndices.toSeq) + row.update(5, sm.values.toSeq) + row.setBoolean(6, sm.isTransposed) + + case dm: DenseMatrix => + row.setByte(0, 1) + row.setInt(1, dm.numRows) + row.setInt(2, dm.numCols) + row.setNullAt(3) + row.setNullAt(4) + row.update(5, dm.values.toSeq) + row.setBoolean(6, dm.isTransposed) + } + row + } + + override def deserialize(datum: Any): Matrix = { + datum match { + // TODO: something wrong with UDT serialization, should never happen. + case m: Matrix => m + case row: Row => + require(row.length == 7, + s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + val tpe = row.getByte(0) + val numRows = row.getInt(1) + val numCols = row.getInt(2) + val values = row.getAs[Iterable[Double]](5).toArray + val isTransposed = row.getBoolean(6) + tpe match { + case 0 => + val colPtrs = row.getAs[Iterable[Int]](3).toArray + val rowIndices = row.getAs[Iterable[Int]](4).toArray + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) + case 1 => + new DenseMatrix(numRows, numCols, values, isTransposed) + } + } + } + + override def userClass: Class[Matrix] = classOf[Matrix] + + override def equals(o: Any): Boolean = { + o match { + case v: MatrixUDT => true + case _ => false + } + } + + override def hashCode(): Int = 1994 + + override def typeName: String = "matrix" + + private[spark] override def asNullable: MatrixUDT = this } /** @@ -107,34 +208,70 @@ sealed trait Matrix extends Serializable { * * @param numRows number of rows * @param numCols number of columns - * @param values matrix entries in column major + * @param values matrix entries in column major if not transposed or in row major otherwise + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. */ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +@SQLUserDefinedType(udt = classOf[MatrixUDT]) +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case m: DenseMatrix => m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + } + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) - private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + override def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v } - override def copy = new DenseMatrix(numRows, numCols, values.clone()) + override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) @@ -148,8 +285,41 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ - def toSparse(): SparseMatrix = { + override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** + * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt @@ -157,9 +327,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -185,8 +354,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. @@ -194,8 +366,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. @@ -213,24 +388,28 @@ object DenseMatrix { } /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } @@ -267,53 +446,78 @@ object DenseMatrix { * * @param numRows number of rows * @param numCols number of columns - * @param colPtrs the index corresponding to the start of a new column - * @param rowIndices the row index of the entry. They must be in strictly increasing order for each - * column - * @param values non-zero matrix entries in column major + * @param colPtrs the index corresponding to the start of a new column (if not transposed) + * @param rowIndices the row index of the entry (if not transposed). They must be in strictly + * increasing order for each column + * @param values nonzero matrix entries in column major (if not transposed) + * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ +@SQLUserDefinedType(udt = classOf[MatrixUDT]) class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - - private[mllib] def apply(i: Int, j: Int): Double = { + override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,11 +526,13 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } - override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + override def copy: SparseMatrix = { + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + } private[mllib] def map(f: Double => Double) = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) @@ -341,8 +547,41 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ - def toDense(): DenseMatrix = { + override def transpose: SparseMatrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. + */ + def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } } @@ -469,7 +708,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. The number of non-zero + * Generate a `SparseMatrix` consisting of `i.i.d`. uniform random numbers. The number of non-zero * elements equal the ceiling of `numRows` x `numCols` x `density` * * @param numRows number of rows of the matrix @@ -484,7 +723,7 @@ object SparseMatrix { } /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d`. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -502,7 +741,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ - def diag(vector: Vector): SparseMatrix = { + def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { case sVec: SparseVector => @@ -557,10 +796,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -569,7 +807,7 @@ object Matrices { } /** - * Generate a `DenseMatrix` consisting of zeros. + * Generate a `Matrix` consisting of zeros. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros @@ -599,7 +837,7 @@ object Matrices { def speye(n: Int): Matrix = SparseMatrix.speye(n) /** - * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -609,7 +847,7 @@ object Matrices { DenseMatrix.rand(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -620,7 +858,7 @@ object Matrices { SparseMatrix.sprand(numRows, numCols, density, rng) /** - * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `DenseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator @@ -630,7 +868,7 @@ object Matrices { DenseMatrix.randn(numRows, numCols, rng) /** - * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * Generate a `SparseMatrix` consisting of `i.i.d.` gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param density the desired density for the matrix @@ -641,8 +879,8 @@ object Matrices { SparseMatrix.sprandn(numRows, numCols, density, rng) /** - * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. - * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * Generate a diagonal matrix in `Matrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ @@ -679,46 +917,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +964,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +977,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7ee0224ad466..4ef171f4f041 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,8 +26,10 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ /** @@ -77,7 +79,7 @@ sealed trait Vector extends Serializable { result = 31 * result + (bits ^ (bits >>> 32)).toInt } } - return result + result } /** @@ -109,9 +111,14 @@ sealed trait Vector extends Serializable { } /** + * :: DeveloperApi :: + * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. + * via [[org.apache.spark.sql.DataFrame]]. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ +@DeveloperApi private[spark] class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { @@ -168,6 +175,19 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" override def userClass: Class[Vector] = classOf[Vector] + + override def equals(o: Any): Boolean = { + o match { + case v: VectorUDT => true + case _ => false + } + } + + override def hashCode: Int = 7919 + + override def typeName: String = "vector" + + private[spark] override def asNullable: VectorUDT = this } /** @@ -207,7 +227,7 @@ object Vectors { * @param elements vector elements in (index, value) pairs. */ def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -215,7 +235,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -233,7 +254,7 @@ object Vectors { } /** - * Creates a dense vector of all zeros. + * Creates a vector of all zeros. * * @param size vector size * @return a zero vector @@ -243,8 +264,7 @@ object Vectors { } /** - * Parses a string resulted from `Vector#toString` into - * an [[org.apache.spark.mllib.linalg.Vector]]. + * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) @@ -290,7 +310,8 @@ object Vectors { * @return norm in L^p^ space. */ def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs @@ -333,7 +354,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. @@ -341,8 +362,10 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -350,12 +373,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -370,18 +393,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } @@ -397,7 +425,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { @@ -457,7 +485,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) - override def apply(i: Int) = values(i) + override def apply(i: Int): Double = values(i) override def copy: DenseVector = { new DenseVector(values.clone()) @@ -476,6 +504,7 @@ class DenseVector(val values: Array[Double]) extends Vector { } object DenseVector { + /** Extracts the value array from a dense vector. */ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } @@ -492,7 +521,9 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.size} indices and " + + s" ${values.size} values.") override def toString: String = "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala new file mode 100644 index 000000000000..3323ae7b1fba --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import scala.collection.mutable.ArrayBuffer + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.{Logging, Partitioner, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A grid partitioner, which uses a regular grid to partition coordinates. + * + * @param rows Number of rows. + * @param cols Number of columns. + * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge. + * @param colsPerPart Number of columns per partition, which may be less at the right edge. + */ +private[mllib] class GridPartitioner( + val rows: Int, + val cols: Int, + val rowsPerPart: Int, + val colsPerPart: Int) extends Partitioner { + + require(rows > 0) + require(cols > 0) + require(rowsPerPart > 0) + require(colsPerPart > 0) + + private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt + private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt + + override val numPartitions: Int = rowPartitions * colPartitions + + /** + * Returns the index of the partition the input coordinate belongs to. + * + * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in + * multiplication. k is ignored in computing partitions. + * @return The index of the partition, which the coordinate belongs to. + */ + override def getPartition(key: Any): Int = { + key match { + case (i: Int, j: Int) => + getPartitionId(i, j) + case (i: Int, j: Int, _: Int) => + getPartitionId(i, j) + case _ => + throw new IllegalArgumentException(s"Unrecognized key: $key.") + } + } + + /** Partitions sub-matrices as blocks with neighboring sub-matrices. */ + private def getPartitionId(i: Int, j: Int): Int = { + require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).") + require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).") + i / rowsPerPart + j / colsPerPart * rowPartitions + } + + override def equals(obj: Any): Boolean = { + obj match { + case r: GridPartitioner => + (this.rows == r.rows) && (this.cols == r.cols) && + (this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart) + case _ => + false + } + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + rows: java.lang.Integer, + cols: java.lang.Integer, + rowsPerPart: java.lang.Integer, + colsPerPart: java.lang.Integer) + } +} + +private[mllib] object GridPartitioner { + + /** Creates a new [[GridPartitioner]] instance. */ + def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = { + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } + + /** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */ + def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = { + require(suggestedNumPartitions > 0) + val scale = 1.0 / math.sqrt(suggestedNumPartitions) + val rowsPerPart = math.round(math.max(scale * rows, 1.0)).toInt + val colsPerPart = math.round(math.max(scale * cols, 1.0)).toInt + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } +} + +/** + * :: Experimental :: + * + * Represents a distributed matrix in blocks of local matrices. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero, + * the number of rows will be calculated when `numRows` is invoked. + * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to + * zero, the number of columns will be calculated when `numCols` is invoked. + */ +@Experimental +class BlockMatrix( + val blocks: RDD[((Int, Int), Matrix)], + val rowsPerBlock: Int, + val colsPerBlock: Int, + private var nRows: Long, + private var nCols: Long) extends DistributedMatrix with Logging { + + private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix) + + /** + * Alternate constructor for BlockMatrix without the input of the number of rows and columns. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + */ + def this( + blocks: RDD[((Int, Int), Matrix)], + rowsPerBlock: Int, + colsPerBlock: Int) = { + this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) + } + + override def numRows(): Long = { + if (nRows <= 0L) estimateDim() + nRows + } + + override def numCols(): Long = { + if (nCols <= 0L) estimateDim() + nCols + } + + val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt + + private[mllib] def createPartitioner(): GridPartitioner = + GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + + private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache() + + /** Estimates the dimensions of the matrix. */ + private def estimateDim(): Unit = { + val (rows, cols) = blockInfo.map { case ((blockRowIndex, blockColIndex), (m, n)) => + (blockRowIndex.toLong * rowsPerBlock + m, + blockColIndex.toLong * colsPerBlock + n) + }.reduce { (x0, x1) => + (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) + } + if (nRows <= 0L) nRows = rows + assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.") + if (nCols <= 0L) nCols = cols + assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") + } + + /** + * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if + * any error is found. + */ + def validate(): Unit = { + logDebug("Validating BlockMatrix...") + // check if the matrix is larger than the claimed dimensions + estimateDim() + logDebug("BlockMatrix dimensions are okay...") + + // Check if there are multiple MatrixBlocks with the same index. + blockInfo.countByKey().foreach { case (key, cnt) => + if (cnt > 1) { + throw new SparkException(s"Found multiple MatrixBlocks with the indices $key. Please " + + "remove blocks with duplicate indices.") + } + } + logDebug("MatrixBlock indices are okay...") + // Check if each MatrixBlock (except edges) has the dimensions rowsPerBlock x colsPerBlock + // The first tuple is the index and the second tuple is the dimensions of the MatrixBlock + val dimensionMsg = s"dimensions different than rowsPerBlock: $rowsPerBlock, and " + + s"colsPerBlock: $colsPerBlock. Blocks on the right and bottom edges can have smaller " + + s"dimensions. You may use the repartition method to fix this issue." + blockInfo.foreach { case ((blockRowIndex, blockColIndex), (m, n)) => + if ((blockRowIndex < numRowBlocks - 1 && m != rowsPerBlock) || + (blockRowIndex == numRowBlocks - 1 && (m <= 0 || m > rowsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + if ((blockColIndex < numColBlocks - 1 && n != colsPerBlock) || + (blockColIndex == numColBlocks - 1 && (n <= 0 || n > colsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + } + logDebug("MatrixBlock dimensions are okay...") + logDebug("BlockMatrix is valid!") + } + + /** Caches the underlying RDD. */ + def cache(): this.type = { + blocks.cache() + this + } + + /** Persists the underlying RDD with the specified storage level. */ + def persist(storageLevel: StorageLevel): this.type = { + blocks.persist(storageLevel) + this + } + + /** Converts to CoordinateMatrix. */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => + val rowStart = blockRowIndex.toLong * rowsPerBlock + val colStart = blockColIndex.toLong * colsPerBlock + val entryValues = new ArrayBuffer[MatrixEntry]() + mat.foreachActive { (i, j, v) => + if (v != 0.0) entryValues.append(new MatrixEntry(rowStart + i, colStart + j, v)) + } + entryValues + } + new CoordinateMatrix(entryRDD, numRows(), numCols()) + } + + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + def toIndexedRowMatrix(): IndexedRowMatrix = { + require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + + s"numCols: ${numCols()}") + // TODO: This implementation may be optimized + toCoordinateMatrix().toIndexedRowMatrix() + } + + /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + def toLocalMatrix(): Matrix = { + require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + + s"Int.MaxValue. Currently numRows: ${numRows()}") + require(numCols() < Int.MaxValue, "The number of columns of this matrix should be less than " + + s"Int.MaxValue. Currently numCols: ${numCols()}") + require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " + + s"less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}") + val m = numRows().toInt + val n = numCols().toInt + val mem = m * n / 125000 + if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!") + val localBlocks = blocks.collect() + val values = new Array[Double](m * n) + localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) => + val rowOffset = blockRowIndex * rowsPerBlock + val colOffset = blockColIndex * colsPerBlock + submat.foreachActive { (i, j, v) => + val indexOffset = (j + colOffset) * m + rowOffset + i + values(indexOffset) = v + } + } + new DenseMatrix(m, n, values) + } + + /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. */ + def transpose: BlockMatrix = { + val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => + ((blockColIndex, blockRowIndex), mat.transpose) + } + new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, nCols, nRows) + } + + /** Collects data and assembles a local dense breeze matrix (for test only). */ + private[mllib] def toBreeze(): BDM[Double] = { + val localMat = toLocalMatrix() + new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) + } + + /** Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + def add(other: BlockMatrix): BlockMatrix = { + require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") + require(numCols() == other.numCols(), "Both matrices must have the same number of columns. " + + s"A.numCols: ${numCols()}, B.numCols: ${other.numCols()}") + if (rowsPerBlock == other.rowsPerBlock && colsPerBlock == other.colsPerBlock) { + val addedBlocks = blocks.cogroup(other.blocks, createPartitioner()) + .map { case ((blockRowIndex, blockColIndex), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), b.head) + } else if (b.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), a.head) + } else { + val result = a.head.toBreeze + b.head.toBreeze + new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) + } + } + new BlockMatrix(addedBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) + } else { + throw new SparkException("Cannot add matrices with different block dimensions") + } + } + + /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + */ + def multiply(other: BlockMatrix): BlockMatrix = { + require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " + + "think they should be equal, try setting the dimensions of A and B explicitly while " + + "initializing them.") + if (colsPerBlock == other.rowsPerBlock) { + val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, + math.max(blocks.partitions.length, other.blocks.partitions.length)) + // Each block of A must be multiplied with the corresponding blocks in each column of B. + // TODO: Optimize to send block to a partition once, similar to ALS + val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) + } + // Each block of B must be multiplied with the corresponding blocks in each row of A. + val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) + } + val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) + .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.nonEmpty && b.nonEmpty) { + val C = b.head match { + case dense: DenseMatrix => a.head.multiply(dense) + case sparse: SparseMatrix => a.head.multiply(sparse.toDense) + case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") + } + Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) + } else { + Iterator() + } + }.reduceByKey(resultPartitioner, (a, b) => a + b) + .mapValues(Matrices.fromBreeze) + // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices + new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) + } else { + throw new SparkException("colsPerBlock of A doesn't match rowsPerBlock of B. " + + s"A.colsPerBlock: $colsPerBlock, B.rowsPerBlock: ${other.rowsPerBlock}") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index b60559c853a5..078d1fac4444 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -21,8 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** * :: Experimental :: @@ -98,6 +97,46 @@ class CoordinateMatrix( toIndexedRowMatrix().toRowMatrix() } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + require(rowsPerBlock > 0, + s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") + require(colsPerBlock > 0, + s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock") + val m = numRows() + val n = numCols() + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt + val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt + val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length) + + val blocks: RDD[((Int, Int), Matrix)] = entries.map { entry => + val blockRowIndex = (entry.i / rowsPerBlock).toInt + val blockColIndex = (entry.j / colsPerBlock).toInt + + val rowId = entry.i % rowsPerBlock + val colId = entry.j % colsPerBlock + + ((blockRowIndex, blockColIndex), (rowId.toInt, colId.toInt, entry.value)) + }.groupByKey(partitioner).map { case ((blockRowIndex, blockColIndex), entry) => + val effRows = math.min(m - blockRowIndex.toLong * rowsPerBlock, rowsPerBlock).toInt + val effCols = math.min(n - blockColIndex.toLong * colsPerBlock, colsPerBlock).toInt + ((blockRowIndex, blockColIndex), SparseMatrix.fromCOO(effRows, effCols, entry)) + } + new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n) + } + /** Determines the size by computing the max row/column index. */ private def computeSize() { // Reduce will throw an exception if `entries` is empty. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index c518271f0472..3be530fa0753 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,24 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + // TODO: This implementation may be optimized + toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) + } + /** * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 02075edbabf8..9a89a6f3a515 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,7 +30,6 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -152,10 +151,10 @@ class RowMatrix( * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined * automatically based on the cost: - * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian - * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. - * This requires a single pass with O(n^2^) storage on each executor and on the driver, and - * O(n^2^ k) time on the driver. + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute + * the Gramian matrix first and then compute its top eigenvalues and eigenvectors locally + * on the driver. This requires a single pass with O(n^2^) storage on each executor and + * on the driver, and O(n^2^ k) time on the driver. * - Otherwise, we compute (A' * A) * v in a distributive way and send it to ARPACK's DSAUPD to * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k) * passes, O(n) storage on each executor, and O(n k) storage on the driver. @@ -220,8 +219,12 @@ class RowMatrix( val computeMode = mode match { case "auto" => + if(k > 5000) { + logWarning(s"computing svd with k=$k and n=$n, please check necessity") + } + // TODO: The conditions below are not fully tested. - if (n < 100 || k > n / 2) { + if (n < 100 || (k > n / 2 && n <= 15000)) { // If n is small or k is large compared with n, we better compute the Gramian matrix first // and then compute its eigenvalues locally, instead of making multiple passes. if (k < n / 3) { @@ -246,6 +249,8 @@ class RowMatrix( val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => + // breeze (v0.10) svd latent constraint, 7 * n * n + 4 * n < Int.MaxValue + require(n < 17515, s"$n exceeds the breeze svd capability") val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) @@ -526,7 +531,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -535,8 +539,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -550,8 +555,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -559,8 +564,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -572,10 +577,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 1ca0f36c6ac3..8bfa0d2b6499 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.optimization import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils @@ -55,24 +55,96 @@ abstract class Gradient extends Serializable { /** * :: DeveloperApi :: - * Compute gradient and loss for a logistic loss function, as used in binary classification. - * See also the documentation for the precise formulation. + * Compute gradient and loss for a multinomial logistic loss function, as used + * in multi-class classification (it is also used in binary logistic regression). + * + * In `The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition` + * by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, which can be downloaded from + * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of + * multinomial logistic regression model. A simple calculation shows that + * + * {{{ + * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i)) + * P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i)) + * ... + * P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i)) + * }}} + * + * for K classes multiclass classification problem. + * + * The model weights w = (w_1, w_2, ..., w_{K-1})^T becomes a matrix which has dimension of + * (K-1) * (N+1) if the intercepts are added. If the intercepts are not added, the dimension + * will be (K-1) * N. + * + * As a result, the loss of objective function for a single instance of data can be written as + * {{{ + * l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w) + * = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1} + * = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * }}} + * + * where \alpha(i) = 1 if i != 0, and + * \alpha(i) = 0 if i == 0, + * margins_i = x w_i. + * + * For optimization, we have to calculate the first derivative of the loss function, and + * a simple calculation shows that + * + * {{{ + * \frac{\partial l(w, x)}{\partial w_{ij}} + * = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y)\delta_{y, i+1})) * x_j + * = multiplier_i * x_j + * }}} + * + * where \delta_{i, j} = 1 if i == j, + * \delta_{i, j} = 0 if i != j, and + * multiplier = + * \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) + * + * If any of margins is larger than 709.78, the numerical computation of multiplier and loss + * function will be suffered from arithmetic overflow. This issue occurs when there are outliers + * in data which are far away from hyperplane, and this will cause the failing of training once + * infinity / infinity is introduced. Note that this is only a concern when max(margins) > 0. + * + * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be + * easily rewritten into the following equivalent numerically stable formula. + * + * {{{ + * l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1} + * = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin + * - (1-\alpha(y)) margins_{y-1} + * = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1} + * }}} + * + * where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1. + * + * Note that each term, (margins_i - maxMargin) in \exp is smaller than zero; as a result, + * overflow will not happen with this formula. + * + * For multiplier, similar trick can be applied as the following, + * + * {{{ + * multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1}) + * = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y)\delta_{y, i+1}) + * }}} + * + * where each term in \exp is also smaller than zero, so overflow is not a concern. + * + * For the detailed mathematical derivation, see the reference at + * http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297 + * + * @param numClasses the number of possible outcomes for k classes classification problem in + * Multinomial Logistic Regression. By default, it is binary logistic regression + * so numClasses will be set to 2. */ @DeveloperApi -class LogisticGradient extends Gradient { - override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val margin = -1.0 * dot(data, weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = data.copy - scal(gradientMultiplier, gradient) - val loss = - if (label > 0) { - // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - MLUtils.log1pExp(margin) - } else { - MLUtils.log1pExp(margin) - margin - } +class LogisticGradient(numClasses: Int) extends Gradient { + def this() = this(2) + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) (gradient, loss) } @@ -81,14 +153,104 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val margin = -1.0 * dot(data, weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - axpy(gradientMultiplier, data, cumGradient) - if (label > 0) { - // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - MLUtils.log1pExp(margin) - } else { - MLUtils.log1pExp(margin) - margin + val dataSize = data.size + + // (weights.size / dataSize + 1) is number of classes + require(weights.size % dataSize == 0 && numClasses == weights.size / dataSize + 1) + numClasses match { + case 2 => + /** + * For Binary Logistic Regression. + * + * Although the loss and gradient calculation for multinomial one is more generalized, + * and multinomial one can also be used in binary case, we still implement a specialized + * binary version for performance reason. + */ + val margin = -1.0 * dot(data, weights) + val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + axpy(multiplier, data, cumGradient) + if (label > 0) { + // The following is equivalent to log(1 + exp(margin)) but more numerically stable. + MLUtils.log1pExp(margin) + } else { + MLUtils.log1pExp(margin) - margin + } + case _ => + /** + * For Multinomial Logistic Regression. + */ + val weightsArray = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + val cumGradientArray = cumGradient match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"cumGradient only supports dense vector but got type ${cumGradient.getClass}.") + } + + // marginY is margins(label - 1) in the formula. + var marginY = 0.0 + var maxMargin = Double.NegativeInfinity + var maxMarginIndex = 0 + + val margins = Array.tabulate(numClasses - 1) { i => + var margin = 0.0 + data.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataSize) + index) + } + if (i == label.toInt - 1) marginY = margin + if (margin > maxMargin) { + maxMargin = margin + maxMarginIndex = i + } + margin + } + + /** + * When maxMargin > 0, the original formula will cause overflow as we discuss + * in the previous comment. + * We address this by subtracting maxMargin from all the margins, so it's guaranteed + * that all of the new margins will be smaller than zero to prevent arithmetic overflow. + */ + val sum = { + var temp = 0.0 + if (maxMargin > 0) { + for (i <- 0 until numClasses - 1) { + margins(i) -= maxMargin + if (i == maxMarginIndex) { + temp += math.exp(-maxMargin) + } else { + temp += math.exp(margins(i)) + } + } + } else { + for (i <- 0 until numClasses - 1) { + temp += math.exp(margins(i)) + } + } + temp + } + + for (i <- 0 until numClasses - 1) { + val multiplier = math.exp(margins(i)) / (sum + 1.0) - { + if (label != 0.0 && label == i + 1) 1.0 else 0.0 + } + data.foreachActive { (index, value) => + if (value != 0.0) cumGradientArray(i * dataSize + index) += multiplier * value + } + } + + val loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum) + + if (maxMargin > 0) { + loss + maxMargin + } else { + loss + } } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 0857877951c8..4b7d0589c973 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} -import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d16d0daf0856..ef6eccd90711 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** @@ -61,6 +60,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. + * This value must be nonnegative. Lower convergence values are less tolerant + * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { this.convergenceTol = tolerance @@ -143,7 +144,9 @@ object LBFGS extends Logging { * one single data example) * @param updater - Updater function to actually perform a gradient step in a given direction. * @param numCorrections - The number of corrections used in the L-BFGS update. - * @param convergenceTol - The convergence tolerance of iterations for L-BFGS + * @param convergenceTol - The convergence tolerance of iterations for L-BFGS which is must be + * nonnegative. Lower values are less tolerant and therefore generally + * cause more iterations to be run. * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. * @param regParam - Regularization parameter * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index fef062e02b6e..4766f7708295 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -17,30 +17,30 @@ package org.apache.spark.mllib.optimization -import org.jblas.{DoubleMatrix, SimpleBlas} +import java.{util => ju} -import org.apache.spark.annotation.DeveloperApi +import com.github.fommil.netlib.BLAS.{getInstance => blas} /** * Object used to solve nonnegative least squares problems using a modified * projected gradient method. */ -private[mllib] object NNLS { +private[spark] object NNLS { class Workspace(val n: Int) { - val scratch = new DoubleMatrix(n, 1) - val grad = new DoubleMatrix(n, 1) - val x = new DoubleMatrix(n, 1) - val dir = new DoubleMatrix(n, 1) - val lastDir = new DoubleMatrix(n, 1) - val res = new DoubleMatrix(n, 1) - - def wipe() { - scratch.fill(0.0) - grad.fill(0.0) - x.fill(0.0) - dir.fill(0.0) - lastDir.fill(0.0) - res.fill(0.0) + val scratch = new Array[Double](n) + val grad = new Array[Double](n) + val x = new Array[Double](n) + val dir = new Array[Double](n) + val lastDir = new Array[Double](n) + val res = new Array[Double](n) + + def wipe(): Unit = { + ju.Arrays.fill(scratch, 0.0) + ju.Arrays.fill(grad, 0.0) + ju.Arrays.fill(x, 0.0) + ju.Arrays.fill(dir, 0.0) + ju.Arrays.fill(lastDir, 0.0) + ju.Arrays.fill(res, 0.0) } } @@ -62,18 +62,18 @@ private[mllib] object NNLS { * direction, however, while this method only uses a conjugate gradient direction if the last * iteration did not cause a previously-inactive constraint to become active. */ - def solve(ata: DoubleMatrix, atb: DoubleMatrix, ws: Workspace): Array[Double] = { + def solve(ata: Array[Double], atb: Array[Double], ws: Workspace): Array[Double] = { ws.wipe() - val n = atb.rows + val n = atb.length val scratch = ws.scratch // find the optimal unconstrained step - def steplen(dir: DoubleMatrix, res: DoubleMatrix): Double = { - val top = SimpleBlas.dot(dir, res) - SimpleBlas.gemv(1.0, ata, dir, 0.0, scratch) + def steplen(dir: Array[Double], res: Array[Double]): Double = { + val top = blas.ddot(n, dir, 1, res, 1) + blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) // Push the denominator upward very slightly to avoid infinities and silliness - top / (SimpleBlas.dot(scratch, dir) + 1e-20) + top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20) } // stopping condition @@ -98,52 +98,52 @@ private[mllib] object NNLS { var i = 0 while (iterno < iterMax) { // find the residual - SimpleBlas.gemv(1.0, ata, x, 0.0, res) - SimpleBlas.axpy(-1.0, atb, res) - SimpleBlas.copy(res, grad) + blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) + blas.daxpy(n, -1.0, atb, 1, res, 1) + blas.dcopy(n, res, 1, grad, 1) // project the gradient i = 0 while (i < n) { - if (grad.data(i) > 0.0 && x.data(i) == 0.0) { - grad.data(i) = 0.0 + if (grad(i) > 0.0 && x(i) == 0.0) { + grad(i) = 0.0 } i = i + 1 } - val ngrad = SimpleBlas.dot(grad, grad) + val ngrad = blas.ddot(n, grad, 1, grad, 1) - SimpleBlas.copy(grad, dir) + blas.dcopy(n, grad, 1, dir, 1) // use a CG direction under certain conditions var step = steplen(grad, res) var ndir = 0.0 - val nx = SimpleBlas.dot(x, x) + val nx = blas.ddot(n, x, 1, x, 1) if (iterno > lastWall + 1) { val alpha = ngrad / lastNorm - SimpleBlas.axpy(alpha, lastDir, dir) + blas.daxpy(n, alpha, lastDir, 1, dir, 1) val dstep = steplen(dir, res) - ndir = SimpleBlas.dot(dir, dir) + ndir = blas.ddot(n, dir, 1, dir, 1) if (stop(dstep, ndir, nx)) { // reject the CG step if it could lead to premature termination - SimpleBlas.copy(grad, dir) - ndir = SimpleBlas.dot(dir, dir) + blas.dcopy(n, grad, 1, dir, 1) + ndir = blas.ddot(n, dir, 1, dir, 1) } else { step = dstep } } else { - ndir = SimpleBlas.dot(dir, dir) + ndir = blas.ddot(n, dir, 1, dir, 1) } // terminate? if (stop(step, ndir, nx)) { - return x.data.clone + return x.clone } // don't run through the walls i = 0 while (i < n) { - if (step * dir.data(i) > x.data(i)) { - step = x.data(i) / dir.data(i) + if (step * dir(i) > x(i)) { + step = x(i) / dir(i) } i = i + 1 } @@ -151,19 +151,19 @@ private[mllib] object NNLS { // take the step i = 0 while (i < n) { - if (step * dir.data(i) > x.data(i) * (1 - 1e-14)) { - x.data(i) = 0 + if (step * dir(i) > x(i) * (1 - 1e-14)) { + x(i) = 0 lastWall = iterno } else { - x.data(i) -= step * dir.data(i) + x(i) -= step * dir(i) } i = i + 1 } iterno = iterno + 1 - SimpleBlas.copy(dir, lastDir) + blas.dcopy(n, dir, 1, lastDir, 1) lastNorm = ngrad } - x.data.clone + x.clone } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 405bae62ee8b..9349ecaa13f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -56,7 +56,7 @@ class UniformGenerator extends RandomDataGenerator[Double] { random.nextDouble() } - override def setSeed(seed: Long) = random.setSeed(seed) + override def setSeed(seed: Long): Unit = random.setSeed(seed) override def copy(): UniformGenerator = new UniformGenerator() } @@ -75,7 +75,7 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { random.nextGaussian() } - override def setSeed(seed: Long) = random.setSeed(seed) + override def setSeed(seed: Long): Unit = random.setSeed(seed) override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 955c593a085d..8341bb86afd7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -29,13 +29,13 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. + * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental object RandomRDDs { /** - * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. + * Generates an RDD comprised of `i.i.d.` samples from the uniform distribution `U(0.0, 1.0)`. * * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. @@ -44,7 +44,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ def uniformRDD( sc: SparkContext, @@ -81,7 +81,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * Generates an RDD comprised of `i.i.d.` samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. @@ -90,7 +90,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ def normalRDD( sc: SparkContext, @@ -127,14 +127,15 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + * Generates an RDD comprised of `i.i.d.` samples from the Poisson distribution with the input + * mean. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def poissonRDD( sc: SparkContext, @@ -177,7 +178,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the exponential distribution with + * Generates an RDD comprised of `i.i.d.` samples from the exponential distribution with * the input mean. * * @param sc SparkContext used to create the RDD. @@ -185,7 +186,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def exponentialRDD( sc: SparkContext, @@ -228,7 +229,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the gamma distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the gamma distribution with the input * shape and scale. * * @param sc SparkContext used to create the RDD. @@ -237,7 +238,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def gammaRDD( sc: SparkContext, @@ -287,7 +288,7 @@ object RandomRDDs { } /** - * Generates an RDD comprised of i.i.d. samples from the log normal distribution with the input + * Generates an RDD comprised of `i.i.d.` samples from the log normal distribution with the input * mean and standard deviation * * @param sc SparkContext used to create the RDD. @@ -296,7 +297,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ def logNormalRDD( sc: SparkContext, @@ -348,14 +349,14 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi def randomRDD[T: ClassTag]( @@ -370,7 +371,7 @@ object RandomRDDs { // TODO Generate RDD[Vector] from multivariate distributions. /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. @@ -424,7 +425,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. @@ -432,7 +433,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ def normalVectorRDD( sc: SparkContext, @@ -478,7 +479,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from a + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from a * log normal distribution. * * @param sc SparkContext used to create the RDD. @@ -488,7 +489,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples. + * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ def logNormalVectorRDD( sc: SparkContext, @@ -544,7 +545,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * Poisson distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -553,7 +554,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ def poissonVectorRDD( sc: SparkContext, @@ -603,7 +604,7 @@ object RandomRDDs { } /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * exponential distribution with the input mean. * * @param sc SparkContext used to create the RDD. @@ -612,7 +613,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ def exponentialVectorRDD( sc: SparkContext, @@ -665,7 +666,7 @@ object RandomRDDs { /** - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples drawn from the * gamma distribution with the input shape and scale. * * @param sc SparkContext used to create the RDD. @@ -675,7 +676,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Exp(mean). + * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ def gammaVectorRDD( sc: SparkContext, @@ -731,7 +732,7 @@ object RandomRDDs { /** * :: DeveloperApi :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the + * Generates an RDD[Vector] with vectors containing `i.i.d.` samples produced by the * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. @@ -740,7 +741,7 @@ object RandomRDDs { * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi def randomVectorRDD(sc: SparkContext, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala new file mode 100644 index 000000000000..9213fd3f595c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * Machine learning specific Pair RDD functions. + */ +@DeveloperApi +class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Serializable { + /** + * Returns the top k (largest) elements for each key from this RDD as defined by the specified + * implicit Ordering[T]. + * If the number of elements for a certain key is less than k, all of them will be returned. + * + * @param num k, the number of top elements to return + * @param ord the implicit ordering for T + * @return an RDD that contains the top k values for each key + */ + def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = { + self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))( + seqOp = (queue, item) => { + queue += item + queue + }, + combOp = (queue1, queue2) => { + queue1 ++= queue2 + queue1 + } + ).mapValues(_.toArray.sorted(ord.reverse)) + } +} + +@DeveloperApi +object MLPairRDDFunctions { + /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */ + implicit def fromPairRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): MLPairRDDFunctions[K, V] = + new MLPairRDDFunctions[K, V](rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 57c0768084e4..78172843be56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -21,10 +21,7 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.HashPartitioner -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * Reduces the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#reduce]] + * @see [[org.apache.spark.rdd.RDD#treeReduce]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead. */ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - val cleanF = self.context.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) - val op: (Option[T], Option[T]) => Option[T] = (c, x) => { - if (c.isDefined && x.isDefined) { - Some(cleanF(c.get, x.get)) - } else if (c.isDefined) { - c - } else if (x.isDefined) { - x - } else { - None - } - } - RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) - .getOrElse(throw new UnsupportedOperationException("empty collection")) - } + @deprecated("Use RDD.treeReduce instead.", "1.3.0") + def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth) /** * Aggregates the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] + * @see [[org.apache.spark.rdd.RDD#treeAggregate]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead. */ + @deprecated("Use RDD.treeAggregate instead.", "1.3.0") def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, depth: Int = 2): U = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (self.partitions.size == 0) { - return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = self.context.clean(seqOp) - val cleanCombOp = self.context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values - } - partiallyAggregated.reduce(cleanCombOp) + self.treeAggregate(zeroValue)(seqOp, combOp, depth) } } @@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ - implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 5f84677be238..dddefe1944e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,52 +17,16 @@ package org.apache.spark.mllib.recommendation -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.math.{abs, sqrt} -import scala.util.{Random, Sorting} -import scala.util.hashing.byteswap32 - -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - -import org.apache.spark.{HashPartitioner, Logging, Partitioner} -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.mllib.optimization.NNLS +import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils /** - * Out-link information for a user or product block. This includes the original user/product IDs - * of the elements within this block, and the list of destination blocks that each user or - * product will need to send its feature vector to. - */ -private[recommendation] -case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[mutable.BitSet]) - - -/** - * In-link information for a user (or product) block. This includes the original user/product IDs - * of the elements within this block, as well as an array of indices and ratings that specify - * which user in the block will be rated by which products from each product block (or vice-versa). - * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, - * indices and ratings, for the i'th product that will be sent to us by product block b (call this - * P). These arrays represent the users that product P had ratings for (by their index in this - * block), as well as the corresponding rating for each one. We can thus use this information when - * we get product block b's message to update the corresponding users. - */ -private[recommendation] case class InLinkBlock( - elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) - - -/** - * :: Experimental :: * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ -@Experimental case class Rating(user: Int, product: Int, rating: Double) /** @@ -118,6 +82,9 @@ class ALS private ( private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** checkpoint interval */ + private var checkpointInterval: Int = 10 + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. @@ -169,10 +136,8 @@ class ALS private ( } /** - * :: Experimental :: * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ - @Experimental def setAlpha(alpha: Double): this.type = { this.alpha = alpha this @@ -201,6 +166,8 @@ class ALS private ( */ @DeveloperApi def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + require(storageLevel != StorageLevel.NONE, + "ALS is not designed to run without persisting intermediate RDDs.") this.intermediateRDDStorageLevel = storageLevel this } @@ -218,6 +185,19 @@ class ALS private ( this } + /** + * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with + * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps + * with eliminating temporary shuffle files on disk, which can be important when there are many + * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]], + * this setting is ignored. + */ + @DeveloperApi + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -236,431 +216,40 @@ class ALS private ( this.numProductBlocks } - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val (userInLinks, userOutLinks) = - makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner) - val (productInLinks, productOutLinks) = - makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner) - userInLinks.setName("userInLinks") - userOutLinks.setName("userOutLinks") - productInLinks.setName("productInLinks") - productOutLinks.setName("productOutLinks") - - // Initialize user and product factors randomly, but use a deterministic seed for each - // partition so that fault recovery works - val seedGen = new Random(seed) - val seed1 = seedGen.nextInt() - val seed2 = seedGen.nextInt() - var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed1 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(byteswap32(seed2 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } + val (floatUserFactors, floatProdFactors) = NewALS.train[Int]( + ratings = ratings.map(r => NewALS.Rating(r.user, r.product, r.rating.toFloat)), + rank = rank, + numUserBlocks = numUserBlocks, + numItemBlocks = numProductBlocks, + maxIter = iterations, + regParam = lambda, + implicitPrefs = implicitPrefs, + alpha = alpha, + nonnegative = nonnegative, + intermediateRDDStorageLevel = intermediateRDDStorageLevel, + finalRDDStorageLevel = StorageLevel.NONE, + checkpointInterval = checkpointInterval, + seed = seed) + + val userFactors = floatUserFactors + .mapValues(_.map(_.toDouble)) + .setName("users") + .persist(finalRDDStorageLevel) + val prodFactors = floatProdFactors + .mapValues(_.map(_.toDouble)) + .setName("products") + .persist(finalRDDStorageLevel) + if (finalRDDStorageLevel != StorageLevel.NONE) { + userFactors.count() + prodFactors.count() } - - if (implicitPrefs) { - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - // Persist users because it will be called twice. - users.setName(s"users-$iter").persist() - val YtY = Some(sc.broadcast(computeYtY(users))) - val previousProducts = products - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY) - previousProducts.unpersist() - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter").persist() - val XtX = Some(sc.broadcast(computeYtY(products))) - val previousUsers = users - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, XtX) - previousUsers.unpersist() - } - } else { - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks, - rank, lambda, alpha, YtY = None) - if (sc.checkpointDir.isDefined && (iter % 3 == 0)) { - products.checkpoint() - } - products.setName(s"products-$iter") - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks, - rank, lambda, alpha, YtY = None) - users.setName(s"users-$iter") - } - } - - // The last `products` will be used twice. One to generate the last `users` and the other to - // generate `productsOut`. So we cache it for better performance. - products.setName("products").persist() - - // Flatten and cache the two final RDDs to un-block them - val usersOut = unblockFactors(users, userOutLinks) - val productsOut = unblockFactors(products, productOutLinks) - - usersOut.setName("usersOut").persist(finalRDDStorageLevel) - productsOut.setName("productsOut").persist(finalRDDStorageLevel) - - // Materialize usersOut and productsOut. - usersOut.count() - productsOut.count() - - products.unpersist() - - // Clean up. - userInLinks.unpersist() - userOutLinks.unpersist() - productInLinks.unpersist() - productOutLinks.unpersist() - - new MatrixFactorizationModel(rank, usersOut, productsOut) + new MatrixFactorizationModel(rank, userFactors, prodFactors) } /** * Java-friendly version of [[ALS.run]]. */ def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) - - /** - * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors - * for each user (or product), in a distributed fashion. - * - * @param factors the (block-distributed) user or product factor vectors - * @return YtY - whose value is only used in the implicit preference model - */ - private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { - val n = rank * (rank + 1) / 2 - val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => { - Y.foreach(y => dspr(1.0, wrapDoubleArray(y), L)) - L - }, combOp = (L1, L2) => { - L1.addi(L2) - }) - val YtY = new DoubleMatrix(rank, rank) - fillFullMatrix(LYtY, YtY) - YtY - } - - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param L the lower triangular part of the matrix packed in an array (row major) - */ - private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = { - val n = x.length - var i = 0 - var j = 0 - var idx = 0 - var axi = 0.0 - val xd = x.data - val Ld = L.data - while (i < n) { - axi = alpha * xd(i) - j = 0 - while (j <= i) { - Ld(idx) += axi * xd(j) - j += 1 - idx += 1 - } - i += 1 - } - } - - /** - * Wrap a double array in a DoubleMatrix without creating garbage. - * This is a temporary fix for jblas 1.2.3; it should be safe to move back to the - * DoubleMatrix(double[]) constructor come jblas 1.2.4. - */ - private def wrapDoubleArray(v: Array[Double]): DoubleMatrix = { - new DoubleMatrix(v.length, 1, v: _*) - } - - /** - * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs - */ - private def unblockFactors( - blockedFactors: RDD[(Int, Array[Array[Double]])], - outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = { - blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - } - - /** - * Make the out-links table for a block of the users (or products) dataset given the list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): OutLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new mutable.BitSet(numProductBlocks)) - for (r <- ratings) { - shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true - } - OutLinkBlock(userIds, shouldSend) - } - - /** - * Make the in-links table for a block of the users (or products) dataset given a list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating], - productPartitioner: Partitioner): InLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val userIdToPos = userIds.zipWithIndex.toMap - // Split out our ratings by product block - val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating]) - for (r <- ratings) { - blockRatings(productPartitioner.getPartition(r.product)) += r - } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks) - for (productBlock <- 0 until numProductBlocks) { - // Create an array of (product, Seq(Rating)) ratings - val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray - // Sort them by product ID - val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { - def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = - a._1 - b._1 - } - Sorting.quickSort(groupedRatings)(ordering) - // Translate the user IDs to indices based on userIdToPos - ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => - (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) - } - } - InLinkBlock(userIds, ratingsForBlock) - } - - /** - * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for - * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid - * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. - */ - private def makeLinkRDDs( - numUserBlocks: Int, - numProductBlocks: Int, - ratingsByUserBlock: RDD[(Int, Rating)], - productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = { - val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks)) - val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map(_._2).toArray - val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) - val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) - Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, preservesPartitioning = true) - val inLinks = links.mapValues(_._1) - val outLinks = links.mapValues(_._2) - inLinks.persist(intermediateRDDStorageLevel) - outLinks.persist(intermediateRDDStorageLevel) - (inLinks, outLinks) - } - - /** - * Make a random factor vector with the given random. - */ - private def randomFactor(rank: Int, rand: Random): Array[Double] = { - // Choose a unit vector uniformly at random from the unit sphere, but from the - // "first quadrant" where all elements are nonnegative. This can be done by choosing - // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. - // This appears to create factorizations that have a slightly better reconstruction - // (<1%) compared picking elements uniformly at random in [0,1]. - val factor = Array.fill(rank)(abs(rand.nextGaussian())) - val norm = sqrt(factor.map(x => x * x).sum) - factor.map(x => x / norm) - } - - /** - * Compute the user feature vectors given the current products (or vice-versa). This first joins - * the products with their out-links to generate a set of messages to each destination block - * (specifically, the features for the products that user block cares about), then groups these - * by destination and joins them with the in-link info to figure out how to update each user. - * It returns an RDD of new feature vectors for each user block. - */ - private def updateFeatures( - numUserBlocks: Int, - products: RDD[(Int, Array[Array[Double]])], - productOutLinks: RDD[(Int, OutLinkBlock)], - userInLinks: RDD[(Int, InLinkBlock)], - rank: Int, - lambda: Double, - alpha: Double, - YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = { - productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) { - if (outLinkBlock.shouldSend(p)(userBlock)) { - toSend(userBlock) += factors(p) - } - } - toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(new HashPartitioner(numUserBlocks)) - .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => - updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) - } - } - - /** - * Compute the new feature vectors for a block of the users matrix given the list of factors - * it received from each product and its InLinkBlock. - */ - private def updateBlock(messages: Iterable[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]]) - : Array[Array[Double]] = - { - // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numProductBlocks = blockFactors.length - val numUsers = inLinkBlock.elementIds.length - - // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since - // the matrices are symmetric - val triangleSize = rank * (rank + 1) / 2 - val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) - val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) - - // Some temp variables to avoid memory allocation - val tempXtX = DoubleMatrix.zeros(triangleSize) - val fullXtX = DoubleMatrix.zeros(rank, rank) - - // Count the number of ratings each user gives to provide user-specific regularization - val numRatings = Array.fill(numUsers)(0) - - // Compute the XtX and Xy values for each user by adding products it rated in each product - // block - for (productBlock <- 0 until numProductBlocks) { - var p = 0 - while (p < blockFactors(productBlock).length) { - val x = wrapDoubleArray(blockFactors(productBlock)(p)) - tempXtX.fill(0.0) - dspr(1.0, x, tempXtX) - val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) - if (implicitPrefs) { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - // Extension to the original paper to handle rs(i) < 0. confidence is a function - // of |rs(i)| instead so that it is never negative: - val confidence = 1 + alpha * abs(rs(i)) - SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i))) - // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) - // means we try to reconstruct 0. We add terms only where P = 1, so, term below - // is now only added for rs(i) > 0: - if (rs(i) > 0) { - SimpleBlas.axpy(confidence, x, userXy(us(i))) - } - i += 1 - } - } else { - var i = 0 - while (i < us.length) { - numRatings(us(i)) += 1 - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - i += 1 - } - } - p += 1 - } - } - - val ws = if (nonnegative) NNLS.createWorkspace(rank) else null - - // Solve the least-squares problem for each user and return the new feature vectors - Array.range(0, numUsers).map { index => - // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(userXtX(index), fullXtX) - // Add regularization - val regParam = numRatings(index) * lambda - var i = 0 - while (i < rank) { - fullXtX.data(i * rank + i) += regParam - i += 1 - } - // Solve the resulting matrix, which is symmetric and positive-definite - if (implicitPrefs) { - solveLeastSquares(fullXtX.addi(YtY.get.value), userXy(index), ws) - } else { - solveLeastSquares(fullXtX, userXy(index), ws) - } - } - } - - /** - * Given A^T A and A^T b, find the x minimising ||Ax - b||_2, possibly subject - * to nonnegativity constraints if `nonnegative` is true. - */ - def solveLeastSquares(ata: DoubleMatrix, atb: DoubleMatrix, - ws: NNLS.Workspace): Array[Double] = { - if (!nonnegative) { - Solve.solvePositive(ata, atb).data - } else { - NNLS.solve(ata, atb, ws) - } - } - - /** - * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square - * matrix that it represents, storing it into destMatrix. - */ - private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { - val rank = destMatrix.rows - var i = 0 - var pos = 0 - while (i < rank) { - var j = 0 - while (j <= i) { - destMatrix.data(i*rank + j) = triangularMatrix.data(pos) - destMatrix.data(j*rank + i) = triangularMatrix.data(pos) - pos += 1 - j += 1 - } - i += 1 - } - } -} - -/** - * Partitioner for ALS. - */ -private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner { - override def getPartition(key: Any): Int = { - Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions) - } - - override def equals(obj: Any): Boolean = { - obj match { - case p: ALSPartitioner => - this.numPartitions == p.numPartitions - case _ => - false - } - } } /** @@ -834,120 +423,4 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } - - /** - * :: DeveloperApi :: - * Statistics of a block in ALS computation. - * - * @param category type of this block, "user" or "product" - * @param index index of this block - * @param count number of users or products inside this block, the same as the number of - * least-squares problems to solve on this block in each iteration - * @param numRatings total number of ratings inside this block, the same as the number of outer - * products we need to make on this block in each iteration - * @param numInLinks total number of incoming links, the same as the number of vectors to retrieve - * before each iteration - * @param numOutLinks total number of outgoing links, the same as the number of vectors to send - * for the next iteration - */ - @DeveloperApi - case class BlockStats( - category: String, - index: Int, - count: Long, - numRatings: Long, - numInLinks: Long, - numOutLinks: Long) - - /** - * :: DeveloperApi :: - * Given an RDD of ratings, number of user blocks, and number of product blocks, computes the - * statistics of each block in ALS computation. This is useful for estimating cost and diagnosing - * load balance. - * - * @param ratings an RDD of ratings - * @param numUserBlocks number of user blocks - * @param numProductBlocks number of product blocks - * @return statistics of user blocks and product blocks - */ - @DeveloperApi - def analyzeBlocks( - ratings: RDD[Rating], - numUserBlocks: Int, - numProductBlocks: Int): Array[BlockStats] = { - - val userPartitioner = new ALSPartitioner(numUserBlocks) - val productPartitioner = new ALSPartitioner(numProductBlocks) - - val ratingsByUserBlock = ratings.map { rating => - (userPartitioner.getPartition(rating.user), rating) - } - val ratingsByProductBlock = ratings.map { rating => - (productPartitioner.getPartition(rating.product), - Rating(rating.product, rating.user, rating.rating)) - } - - val als = new ALS() - val (userIn, userOut) = - als.makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, userPartitioner) - val (prodIn, prodOut) = - als.makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, productPartitioner) - - def sendGrid(outLinks: RDD[(Int, OutLinkBlock)]): Map[(Int, Int), Long] = { - outLinks.map { x => - val grid = new mutable.HashMap[(Int, Int), Long]() - val uPartition = x._1 - x._2.shouldSend.foreach { ss => - ss.foreach { pPartition => - val pair = (uPartition, pPartition) - grid.put(pair, grid.getOrElse(pair, 0L) + 1L) - } - } - grid - }.reduce { (grid1, grid2) => - grid2.foreach { x => - grid1.put(x._1, grid1.getOrElse(x._1, 0L) + x._2) - } - grid1 - }.toMap - } - - val userSendGrid = sendGrid(userOut) - val prodSendGrid = sendGrid(prodOut) - - val userInbound = new Array[Long](numUserBlocks) - val prodInbound = new Array[Long](numProductBlocks) - val userOutbound = new Array[Long](numUserBlocks) - val prodOutbound = new Array[Long](numProductBlocks) - - for (u <- 0 until numUserBlocks; p <- 0 until numProductBlocks) { - userOutbound(u) += userSendGrid.getOrElse((u, p), 0L) - prodInbound(p) += userSendGrid.getOrElse((u, p), 0L) - userInbound(u) += prodSendGrid.getOrElse((p, u), 0L) - prodOutbound(p) += prodSendGrid.getOrElse((p, u), 0L) - } - - val userCounts = userOut.mapValues(x => x.elementIds.length).collectAsMap() - val prodCounts = prodOut.mapValues(x => x.elementIds.length).collectAsMap() - - val userRatings = countRatings(userIn) - val prodRatings = countRatings(prodIn) - - val userStats = Array.tabulate(numUserBlocks)( - u => BlockStats("user", u, userCounts(u), userRatings(u), userInbound(u), userOutbound(u))) - val productStatus = Array.tabulate(numProductBlocks)( - p => BlockStats("product", p, prodCounts(p), prodRatings(p), prodInbound(p), prodOutbound(p))) - - (userStats ++ productStatus).toArray - } - - private def countRatings(inLinks: RDD[(Int, InLinkBlock)]): Map[Int, Long] = { - inLinks.mapValues { ilb => - var numRatings = 0L - ilb.ratingsForBlock.foreach { ar => - ar.foreach { p => numRatings += p._1.length } - } - numRatings - }.collectAsMap().toMap - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b41bcae..36cbf060d999 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,20 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} -import org.jblas.DoubleMatrix +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.storage.StorageLevel /** @@ -41,7 +48,8 @@ import org.apache.spark.storage.StorageLevel class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + val productFeatures: RDD[(Int, Array[Double])]) + extends Saveable with Serializable with Logging { require(rank > 0) validateFeatures("User", userFeatures) @@ -62,9 +70,9 @@ class MatrixFactorizationModel( /** Predict the rating of one user for one product. */ def predict(user: Int, product: Int): Double = { - val userVector = new DoubleMatrix(userFeatures.lookup(user).head) - val productVector = new DoubleMatrix(productFeatures.lookup(product).head) - userVector.dot(productVector) + val userVector = userFeatures.lookup(user).head + val productVector = productFeatures.lookup(product).head + blas.ddot(userVector.length, userVector, 1, productVector, 1) } /** @@ -81,9 +89,7 @@ class MatrixFactorizationModel( } users.join(productFeatures).map { case (product, ((user, uFeatures), pFeatures)) => - val userVector = new DoubleMatrix(uFeatures) - val productVector = new DoubleMatrix(pFeatures) - Rating(user, product, userVector.dot(productVector)) + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) } } @@ -125,14 +131,87 @@ class MatrixFactorizationModel( recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + protected override val formatVersion: String = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) + } + private def recommend( recommendToFeatures: Array[Double], recommendableFeatures: RDD[(Int, Array[Double])], num: Int): Array[(Int, Double)] = { - val recommendToVector = new DoubleMatrix(recommendToFeatures) val scored = recommendableFeatures.map { case (id,features) => - (id, recommendToVector.dot(new DoubleMatrix(features))) + (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1)) } scored.top(num)(Ordering.by(_._2)) } } + +object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { + + import org.apache.spark.mllib.util.Loader._ + + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private[recommendation] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[recommendation] + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. + */ + def save(model: MatrixFactorizationModel, path: String): Unit = { + val sc = model.userFeatures.sparkContext + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) + model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + } + + def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rank = (metadata \ "rank").extract[Int] + val userFeatures = sqlContext.parquetFile(userPath(path)) + .map { case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.parquetFile(productPath(path)) + .map { case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + new MatrixFactorizationModel(rank, userFeatures, productFeatures) + } + + private def userPath(path: String): String = { + new Path(dataPath(path), "user").toUri.toString + } + + private def productPath(path: String): String = { + new Path(dataPath(path), "product").toUri.toString + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 0287f04e2c77..9fd60ff7a0c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -76,7 +76,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double predictPoint(testData, weights, intercept) } - override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) + /** + * Print a summary of the model. + */ + override def toString: String = { + s"${this.getClass.getName}: intercept = ${intercept}, numFeatures = ${weights.size}" + } } /** @@ -98,6 +103,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected var validateData: Boolean = true + /** + * In `GeneralizedLinearModel`, only single linear predictor is allowed for both weights + * and intercept. However, for multinomial logistic regression, with K possible outcomes, + * we are training K-1 independent binary logistic regression models which requires K-1 sets + * of linear predictor. + * + * As a result, the workaround here is if more than two sets of linear predictors are needed, + * we construct bigger `weights` vector which can hold both weights and intercepts. + * If the intercepts are added, the dimension of `weights` will be + * (numOfLinearPredictor) * (numFeatures + 1) . If the intercepts are not added, + * the dimension of `weights` will be (numOfLinearPredictor) * numFeatures. + * + * Thus, the intercepts will be encapsulated into weights, and we leave the value of intercept + * in GeneralizedLinearModel as zero. + */ + protected var numOfLinearPredictor: Int = 1 + /** * Whether to perform feature scaling before model training to reduce the condition numbers * which can significantly help the optimizer converging faster. The scaling correction will be @@ -106,6 +128,16 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ private var useFeatureScaling = false + /** + * The dimension of training features. + */ + def getNumFeatures: Int = this.numFeatures + + /** + * The dimension of training features. + */ + protected var numFeatures: Int = -1 + /** * Set if the algorithm should use feature scaling to improve the convergence during optimization. */ @@ -119,6 +151,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ protected def createModel(weights: Vector, intercept: Double): M + /** + * Get if the algorithm uses addIntercept + */ + def isAddIntercept: Boolean = this.addIntercept + /** * Set if the algorithm should add an intercept. Default false. * We set the default to false because adding the intercept will cause memory allocation. @@ -141,8 +178,30 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * RDD of LabeledPoint entries. */ def run(input: RDD[LabeledPoint]): M = { - val numFeatures: Int = input.first().features.size - val initialWeights = Vectors.dense(new Array[Double](numFeatures)) + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + + /** + * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights, + * so the `weights` will include the intercepts. When `numOfLinearPredictor == 1`, + * the intercept will be stored as separated value in `GeneralizedLinearModel`. + * This will result in different behaviors since when `numOfLinearPredictor == 1`, + * users have no way to set the initial intercept, while in the other case, users + * can set the intercepts as part of weights. + * + * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always + * have the intercept as part of weights to have consistent design. + */ + val initialWeights = { + if (numOfLinearPredictor == 1) { + Vectors.dense(new Array[Double](numFeatures)) + } else if (addIntercept) { + Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + } else { + Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + } + } run(input, initialWeights) } @@ -152,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") @@ -162,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - /** + /* * Scaling columns to unit variance as a heuristic to reduce the condition number: * * During the optimization process, the convergence (rate) depends on the condition number of @@ -182,42 +245,53 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Currently, it's only enabled in LogisticRegressionWithLBFGS */ val scaler = if (useFeatureScaling) { - (new StandardScaler).fit(input.map(x => x.features)) + new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features)) } else { null } // Prepend an extra variable consisting of all 1.0's for the intercept. - val data = if (addIntercept) { - if(useFeatureScaling) { - input.map(labeledPoint => - (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) - } else { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) - } - } else { - if (useFeatureScaling) { - input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + // TODO: Apply feature scaling to the weight vector instead of input data. + val data = + if (addIntercept) { + if (useFeatureScaling) { + input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache() + } else { + input.map(lp => (lp.label, appendBias(lp.features))).cache() + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(lp => (lp.label, scaler.transform(lp.features))).cache() + } else { + input.map(lp => (lp.label, lp.features)) + } } - } - val initialWeightsWithIntercept = if (addIntercept) { + /** + * TODO: For better convergence, in logistic regression, the intercepts should be computed + * from the prior probability distribution of the outcomes; for linear regression, + * the intercept should be set as the average of response. + */ + val initialWeightsWithIntercept = if (addIntercept && numOfLinearPredictor == 1) { appendBias(initialWeights) } else { + /** If `numOfLinearPredictor > 1`, initialWeights already contains intercepts. */ initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - var weights = - if (addIntercept) { - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) - } else { - weightsWithIntercept - } + val intercept = if (addIntercept && numOfLinearPredictor == 1) { + weightsWithIntercept(weightsWithIntercept.size - 1) + } else { + 0.0 + } + + var weights = if (addIntercept && numOfLinearPredictor == 1) { + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) + } else { + weightsWithIntercept + } /** * The weights and intercept are trained in the scaled space; we're converting them back to @@ -228,7 +302,29 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * is the coefficient in the original space, and v_i is the variance of the column i. */ if (useFeatureScaling) { - weights = scaler.transform(weights) + if (numOfLinearPredictor == 1) { + weights = scaler.transform(weights) + } else { + /** + * For `numOfLinearPredictor > 1`, we have to transform the weights back to the original + * scale for each set of linear predictor. Note that the intercepts have to be explicitly + * excluded when `addIntercept == true` since the intercepts are part of weights now. + */ + var i = 0 + val n = weights.size / numOfLinearPredictor + val weightsArray = weights.toArray + while (i < numOfLinearPredictor) { + val start = i * n + val end = (i + 1) * n - { if (addIntercept) 1 else 0 } + + val partialWeightsArray = scaler.transform( + Vectors.dense(weightsArray.slice(start, end))).toArray + + System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size) + i += 1 + } + weights = Vectors.dense(weightsArray) + } } // Warn at the end of the run as well, for increased visibility. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala new file mode 100644 index 000000000000..1d7617046b6c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import java.io.Serializable +import java.lang.{Double => JDouble} +import java.util.Arrays.binarySearch + +import scala.collection.mutable.ArrayBuffer + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +/** + * :: Experimental :: + * + * Regression model for isotonic regression. + * + * @param boundaries Array of boundaries for which predictions are known. + * Boundaries must be sorted in increasing order. + * @param predictions Array of predictions associated to the boundaries at the same index. + * Results of isotonic regression and therefore monotone. + * @param isotonic indicates whether this is isotonic or antitonic. + */ +@Experimental +class IsotonicRegressionModel ( + val boundaries: Array[Double], + val predictions: Array[Double], + val isotonic: Boolean) extends Serializable with Saveable { + + private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse + + require(boundaries.length == predictions.length) + assertOrdered(boundaries) + assertOrdered(predictions)(predictionOrd) + + /** Asserts the input array is monotone with the given ordering. */ + private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { + var i = 1 + while (i < xs.length) { + require(ord.compare(xs(i - 1), xs(i)) <= 0, + s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") + i += 1 + } + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: RDD[Double]): RDD[Double] = { + testData.map(predict) + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) + } + + /** + * Predict a single label. + * Using a piecewise linear function. + * + * @param testData Feature to be labeled. + * @return Predicted label. + * 1) If testData exactly matches a boundary then associated prediction is returned. + * In case there are multiple predictions with the same boundary then one of them + * is returned. Which one is undefined (same as java.util.Arrays.binarySearch). + * 2) If testData is lower or higher than all boundaries then first or last prediction + * is returned respectively. In case there are multiple predictions with the same + * boundary then the lowest or highest is returned respectively. + * 3) If testData falls between two values in boundary array then prediction is treated + * as piecewise linear function and interpolated value is returned. In case there are + * multiple values with the same boundary then the same rules as in 2) are used. + */ + def predict(testData: Double): Double = { + + def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { + y1 + (y2 - y1) * (x - x1) / (x2 - x1) + } + + val foundIndex = binarySearch(boundaries, testData) + val insertIndex = -foundIndex - 1 + + // Find if the index was lower than all values, + // higher than all values, in between two values or exact match. + if (insertIndex == 0) { + predictions.head + } else if (insertIndex == boundaries.length){ + predictions.last + } else if (foundIndex < 0) { + linearInterpolation( + boundaries(insertIndex - 1), + predictions(insertIndex - 1), + boundaries(insertIndex), + predictions(insertIndex), + testData) + } else { + predictions(foundIndex) + } + } + + override def save(sc: SparkContext, path: String): Unit = { + IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) + } + + override protected def formatVersion: String = "1.0" +} + +object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel" + + /** Model data for model import/export */ + case class Data(boundary: Double, prediction: Double) + + def save( + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean): Unit = { + val sqlContext = new SQLContext(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("isotonic" -> isotonic))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + sqlContext.createDataFrame( + boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } + ).saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(dataPath(path)) + + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("boundary", "prediction").collect() + val (boundaries, predictions) = dataArray.map { x => + (x.getDouble(0), x.getDouble(1)) + }.toList.sortBy(_._1).unzip + (boundaries.toArray, predictions.toArray) + } + } + + override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + val isotonic = (metadata \ "isotonic").extract[Boolean] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (boundaries, predictions) = SaveLoadV1_0.load(sc, path) + new IsotonicRegressionModel(boundaries, predictions, isotonic) + case _ => throw new Exception( + s"IsotonicRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)" + ) + } + } +} + +/** + * :: Experimental :: + * + * Isotonic regression. + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Sequential PAV implementation based on: + * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + * Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + * + * Sequential PAV parallelization based on: + * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + * "An approach to parallelizing isotonic regression." + * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + * Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + * + * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + */ +@Experimental +class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { + + /** + * Constructs IsotonicRegression instance with default parameter isotonic = true. + * + * @return New instance of IsotonicRegression. + */ + def this() = this(true) + + /** + * Sets the isotonic parameter. + * + * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. + * @return This instance of IsotonicRegression. + */ + def setIsotonic(isotonic: Boolean): this.type = { + this.isotonic = isotonic + this + } + + /** + * Run IsotonicRegression algorithm to obtain isotonic regression model. + * + * @param input RDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. + * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { + val preprocessedInput = if (isotonic) { + input + } else { + input.map(x => (-x._1, x._2, x._3)) + } + + val pooled = parallelPoolAdjacentViolators(preprocessedInput) + + val predictions = if (isotonic) pooled.map(_._1) else pooled.map(-_._1) + val boundaries = pooled.map(_._2) + + new IsotonicRegressionModel(boundaries, predictions, isotonic) + } + + /** + * Run pool adjacent violators algorithm to obtain isotonic regression model. + * + * @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. + * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { + run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) + } + + /** + * Performs a pool adjacent violators algorithm (PAV). + * Uses approach with single processing of data where violators + * in previously processed data created by pooling are fixed immediately. + * Uses optimization of discovering monotonicity violating sequences (blocks). + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def poolAdjacentViolators( + input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + + if (input.isEmpty) { + return Array.empty + } + + // Pools sub array within given bounds assigning weighted average value to all elements. + def pool(input: Array[(Double, Double, Double)], start: Int, end: Int): Unit = { + val poolSubArray = input.slice(start, end + 1) + + val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum + val weight = poolSubArray.map(_._3).sum + + var i = start + while (i <= end) { + input(i) = (weightedSum / weight, input(i)._2, input(i)._3) + i = i + 1 + } + } + + var i = 0 + while (i < input.length) { + var j = i + + // Find monotonicity violating sequence, if any. + while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + j = j + 1 + } + + // If monotonicity was not violated, move to next data point. + if (i == j) { + i = i + 1 + } else { + // Otherwise pool the violating sequence + // and check if pooling caused monotonicity violation in previously processed points. + while (i >= 0 && input(i)._1 > input(i + 1)._1) { + pool(input, i, j) + i = i - 1 + } + + i = j + } + } + + // For points having the same prediction, we only keep two boundary points. + val compressed = ArrayBuffer.empty[(Double, Double, Double)] + + var (curLabel, curFeature, curWeight) = input.head + var rightBound = curFeature + def merge(): Unit = { + compressed += ((curLabel, curFeature, curWeight)) + if (rightBound > curFeature) { + compressed += ((curLabel, rightBound, 0.0)) + } + } + i = 1 + while (i < input.length) { + val (label, feature, weight) = input(i) + if (label == curLabel) { + curWeight += weight + rightBound = feature + } else { + merge() + curLabel = label + curFeature = feature + curWeight = weight + rightBound = curFeature + } + i += 1 + } + merge() + + compressed.toArray + } + + /** + * Performs parallel pool adjacent violators algorithm. + * Performs Pool adjacent violators algorithm on each partition and then again on the result. + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def parallelPoolAdjacentViolators( + input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + val parallelStepResult = input + .sortBy(x => (x._2, x._1)) + .glom() + .flatMap(poolAdjacentViolators) + .collect() + .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering. + poolAdjacentViolators(parallelStepResult) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad93c..e8b03816573c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -32,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +42,37 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LassoModel extends Loader[LassoModel] { + + override def load(sc: SparkContext, path: String): LassoModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff..6fa7ad52a5b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. @@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -38,12 +42,37 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object LinearRegressionModel extends Loader[LinearRegressionModel] { + + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a..214ac4d0ed7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,10 +17,12 @@ package org.apache.spark.mllib.regression +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD @Experimental trait RegressionModel extends Serializable { @@ -48,3 +50,15 @@ trait RegressionModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. + * @return numFeatures + */ + def getNumFeatures(metadata: JValue): Int = { + implicit val formats = DefaultFormats + (metadata \ "numFeatures").extract[Int] + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051c9..309f9af46645 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. @@ -32,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +43,37 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = "1.0" +} + +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val numFeatures = RegressionModel.getNumFeatures(metadata) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. @@ -143,7 +171,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 0.01) + train(input, numIterations, stepSize, regParam, 1.0) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b549b7c475fc..cea8f3f47307 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream /** @@ -39,14 +41,14 @@ import org.apache.spark.streaming.dstream.DStream * * For example usage, see `StreamingLinearRegressionWithSGD`. * - * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn + * NOTE: In some use cases, the order in which trainOn and predictOn * are called in an application will affect the results. When called on * the same DStream, if trainOn is called before predictOn, when new data * arrive the model will update and the prediction will be based on the new * model. Whereas if predictOn is called first, the prediction will use the model * from the previous update. * - * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this + * NOTE: It is ok to call predictOn repeatedly on multiple streams; this * will generate predictions for each one all using the current model. * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. @@ -58,14 +60,14 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: M + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A /** Return the latest model. */ def latestModel(): M = { - model + model.get } /** @@ -76,22 +78,32 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing labeled data */ - def trainOn(data: DStream[LabeledPoint]) { - if (Option(model.weights) == None) { - logError("Initial weights must be set before starting training") - throw new IllegalArgumentException + def trainOn(data: DStream[LabeledPoint]): Unit = { + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = algorithm.run(rdd, model.weights) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.weights.size match { - case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.weights.toArray.mkString("[", ",", "]") + val initialWeights = + model match { + case Some(m) => + m.weights + case None => + val numFeatures = rdd.first().features.size + Vectors.dense(numFeatures) } - logInfo("Current model: weights, %s".format (display)) + model = Some(algorithm.run(rdd, initialWeights)) + logInfo("Model updated at time %s".format(time.toString)) + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo("Current model: weights, %s".format (display)) } } + /** Java-friendly version of `trainOn`. */ + def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) + /** * Use the model to make predictions on batches of data from a DStream * @@ -99,12 +111,15 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing predictions */ def predictOn(data: DStream[Vector]): DStream[Double] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.predict) + data.map{x => model.get.predict(x)} + } + + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } /** @@ -114,11 +129,17 @@ abstract class StreamingLinearAlgorithm[ * @return DStream containing the input keys and the predictions as values */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { - if (Option(model.weights) == None) { - val msg = "Initial weights must be set before starting prediction" - logError(msg) - throw new IllegalArgumentException(msg) + if (model.isEmpty) { + throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.predict) + data.mapValues{x => model.get.predict(x)} + } + + + /** Java-friendly version of `predictOnValues`. */ + def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Double)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 1d11fde24712..a49153bf73c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector /** + * :: Experimental :: * Train or predict a linear regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) @@ -41,13 +42,12 @@ import org.apache.spark.mllib.linalg.Vector * */ @Experimental -class StreamingLinearRegressionWithSGD ( +class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, - private var miniBatchFraction: Double, - private var initialWeights: Vector) - extends StreamingLinearAlgorithm[ - LinearRegressionModel, LinearRegressionWithSGD] with Serializable { + private var miniBatchFraction: Double) + extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] + with Serializable { /** * Construct a StreamingLinearRegression object with default parameters: @@ -55,11 +55,11 @@ class StreamingLinearRegressionWithSGD ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ - def this() = this(0.1, 50, 1.0, null) + def this() = this(0.1, 50, 1.0) val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) - var model = algorithm.createModel(initialWeights, 0.0) + protected var model: Option[LinearRegressionModel] = None /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { @@ -81,7 +81,7 @@ class StreamingLinearRegressionWithSGD ( /** Set the initial weights. Default: [0.0, 0.0]. */ def setInitialWeights(initialWeights: Vector): this.type = { - this.model = algorithm.createModel(initialWeights, 0.0) + this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 000000000000..b55944f74f62 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression.impl + +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + /** + * Helper method for saving GLM regression model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + */ + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> weights.size))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. + */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(datapath) + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + data match { + case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") + Data(weights, intercept) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala new file mode 100644 index 000000000000..0deef11b4511 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.rdd.RDD + +private[stat] object KernelDensity { + /** + * Given a set of samples from a distribution, estimates its density at the set of given points. + * Uses a Gaussian kernel with the given standard deviation. + */ + def estimate(samples: RDD[Double], standardDeviation: Double, + evaluationPoints: Array[Double]): Array[Double] = { + if (standardDeviation <= 0.0) { + throw new IllegalArgumentException("Standard deviation must be positive") + } + + // This gets used in each Gaussian PDF computation, so compute it up front + val logStandardDeviationPlusHalfLog2Pi = + Math.log(standardDeviation) + 0.5 * Math.log(2 * Math.PI) + + val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))( + (x, y) => { + var i = 0 + while (i < evaluationPoints.length) { + x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi, + evaluationPoints(i)) + i += 1 + } + (x._1, i) + }, + (x, y) => { + var i = 0 + while (i < evaluationPoints.length) { + x._1(i) += y._1(i) + i += 1 + } + (x._1, x._2 + y._2) + }) + + var i = 0 + while (i < points.length) { + points(i) /= count + i += 1 + } + points + } + + private def normPdf(mean: Double, standardDeviation: Double, + logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = { + val x0 = x - mean + val x1 = x0 / standardDeviation + val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi + Math.exp(logDensity) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf4e807b4cf..32561620ac91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -26,36 +26,32 @@ import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** + * :: Experimental :: * API for statistical functions in MLlib. */ @Experimental object Statistics { /** - * :: Experimental :: * Computes column-wise summary statistics for the input RDD[Vector]. * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. */ - @Experimental def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } /** - * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** - * :: Experimental :: * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -69,11 +65,9 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. */ - @Experimental def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** - * :: Experimental :: * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * @@ -84,11 +78,9 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ - @Experimental def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** - * :: Experimental :: * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * @@ -99,14 +91,12 @@ object Statistics { * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` - *@return A Double containing the correlation between the two input RDD[Double]s using the + * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. */ - @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * @@ -120,13 +110,11 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } /** - * :: Experimental :: * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * @@ -136,11 +124,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** - * :: Experimental :: * Conduct Pearson's independence test on the input contingency matrix, which cannot contain * negative entries or columns or rows that sum up to 0. * @@ -148,11 +134,9 @@ object Statistics { * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. */ - @Experimental def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** - * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which * the chi-squared statistic is computed. All label and feature values must be categorical. @@ -162,8 +146,21 @@ object Statistics { * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. */ - @Experimental def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + + /** + * Given an empirical distribution defined by the input RDD of samples, estimate its density at + * each of the given evaluation points using a Gaussian kernel. + * + * @param samples The samples RDD used to define the empirical distribution. + * @param standardDeviation The standard deviation of the kernel Gaussians. + * @param evaluationPoints The points at which to estimate densities. + * @return An array the same size as evaluationPoints with the density at each point. + */ + def kernelDensity(samples: RDD[Double], standardDeviation: Double, + evaluationPoints: Iterable[Double]): Array[Double] = { + KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index fd186b5ee6f7..cd6add9d60b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} @@ -62,21 +62,21 @@ class MultivariateGaussian ( /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { - pdf(x.toBreeze.toDenseVector) + pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { - logpdf(x.toBreeze.toDenseVector) + logpdf(x.toBreeze) } /** Returns density of this multivariate Gaussian at given point, x */ - private[mllib] def pdf(x: DBV[Double]): Double = { + private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } /** Returns the log-density of this multivariate Gaussian at given point, x */ - private[mllib] def logpdf(x: DBV[Double]): Double = { + private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b3e8ed9af8c5..dfe3a0b6913e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - +import scala.collection.mutable.ArrayBuilder +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.SparkContext._ - /** * :: Experimental :: @@ -331,14 +327,14 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (feature, bin). * @param treePoint Data point being aggregated. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param splits possible splits indexed (numFeatures)(numSplits) * @param unorderedFeatures Set of indices of unordered features. * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - bins: Array[Array[Bin]], + splits: Array[Array[Split]], unorderedFeatures: Set[Int], instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { @@ -366,7 +362,7 @@ object DecisionTree extends Serializable with Logging { val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { - if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } else { @@ -510,8 +506,8 @@ object DecisionTree extends Serializable with Logging { if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) } } } @@ -1028,35 +1024,15 @@ object DecisionTree extends Serializable with Logging { // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // TODO: The second half of the bins are unused. Actually, we could just use - // splits and not build bins for unordered features. That should be part of - // a later PR since it will require changing other code (using splits instead - // of bins in a few places). // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations + // 2^(maxFeatureValue - 1) - 1 combinations splits(featureIndex) = new Array[Split](numSplits) - bins(featureIndex) = new Array[Bin](numBins) var splitIndex = 0 while (splitIndex < numSplits) { val categories: List[Double] = extractMultiClassCategories(splitIndex + 1, featureArity) splits(featureIndex)(splitIndex) = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(splitIndex) = { - if (splitIndex == 0) { - new Bin( - new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), - Categorical, - Double.MinValue) - } else { - new Bin( - splits(featureIndex)(splitIndex - 1), - splits(featureIndex)(splitIndex), - Categorical, - Double.MinValue) - } - } splitIndex += 1 } } else { @@ -1064,8 +1040,11 @@ object DecisionTree extends Serializable with Logging { // Bins correspond to feature values, so we do not need to compute splits or bins // beforehand. Splits are constructed as needed during training. splits(featureIndex) = new Array[Split](0) - bins(featureIndex) = new Array[Bin](0) } + // For ordered features, bins correspond to feature values. + // For unordered categorical features, there is no need to construct the bins. + // since there is a one-to-one correspondence between the splits and the bins. + bins(featureIndex) = new Array[Bin](0) } featureIndex += 1 } @@ -1140,7 +1119,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splits = new ArrayBuffer[Double] + val splitsBuilder = ArrayBuilder.make[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1158,17 +1137,20 @@ object DecisionTree extends Serializable with Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splits.append(valueCounts(index - 1)._1) + splitsBuilder += valueCounts(index - 1)._1 targetCount += stride } index += 1 } - splits.toArray + splitsBuilder.result() } } - assert(splits.length > 0) + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") // set number of splits accordingly metadata.setNumSplits(featureIndex, splits.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 61f6b1313f82..0e31c7ed58df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, + remappedInput, boostingStrategy, validate=false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } -} + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @return a gradient boosted trees model that can be used for prediction + */ + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => GradientBoostedTrees.boost( + input, validationInput, boostingStrategy, validate=true) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate=true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. + */ + def runWithValidation( + input: JavaRDD[LabeledPoint], + validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(input.rdd, validationInput.rdd) + } +} object GradientBoostedTrees extends Logging { @@ -108,13 +147,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { - + validationInput: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy, + validate: Boolean): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") timer.start("init") @@ -129,6 +171,7 @@ object GradientBoostedTrees extends Logging { val learningRate = boostingStrategy.learningRate // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol treeStrategy.algo = Regression treeStrategy.impurity = Variance treeStrategy.assertValid() @@ -148,16 +191,26 @@ object GradientBoostedTrees extends Logging { // Initialize tree timer.start("building tree 0") val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) + baseLearnerWeights(0) = firstTreeWeight + + var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + logDebug("error of gbt = " + predError.values.mean()) + // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) + var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 + var bestM = 1 + + // pseudo-residual for second iteration + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } var m = 1 while (m < numIterations) { @@ -175,11 +228,36 @@ object GradientBoostedTrees extends Logging { baseLearnerWeights(m) = learningRate // Note: A model of type regression is used since we require raw prediction val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) + Regression, baseLearners.slice(0, m + 1), + baseLearnerWeights.slice(0, m + 1)) + + predError = GradientBoostedTreesModel.updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + logDebug("error of gbt = " + predError.values.mean()) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. + + validatePredError = GradientBoostedTreesModel.updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + val currentValidateError = validatePredError.values.mean() + if (bestValidateError - currentValidateError < validationTol) { + return new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) + data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } m += 1 } @@ -187,8 +265,15 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + if (validate) { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else { + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c65..055e60c7d9c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import java.io.IOException + import scala.collection.mutable import scala.collection.JavaConverters._ @@ -140,6 +142,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. @@ -155,19 +158,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -208,7 +204,6 @@ private class RandomForest ( Some(NodeIdCache.init( data = baggedInput, numTrees = numTrees, - checkpointDir = strategy.checkpointDir, checkpointInterval = strategy.checkpointInterval, initVal = 1)) } else { @@ -250,7 +245,12 @@ private class RandomForest ( // Delete any remaining checkpoints used for node Id cache. if (nodeIdCache.nonEmpty) { - nodeIdCache.get.deleteAllCheckpoints() + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e:IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 0ef9c6181a0a..b6099259971b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -29,8 +29,8 @@ object Algo extends Enumeration { val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { - case "classification" => Classification - case "regression" => Regression + case "classification" | "Classification" => Classification + case "regression" | "Regression" => Regression case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index ed8e6a796f8c..2d6b01524ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] + * @param validationTol Useful when runWithValidation is used. If the error rate on the + * validation input between two iterations is less than the validationTol + * then stop. Ignored when [[run]] is used. */ @Experimental case class BoostingStrategy( @@ -42,7 +45,8 @@ case class BoostingStrategy( @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1) extends Serializable { + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. @@ -85,14 +89,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: Algo): BoostingStrategy = { - val treeStragtegy = Strategy.defaultStategy(algo) - treeStragtegy.maxDepth = 3 + val treeStrategy = Strategy.defaultStategy(algo) + treeStrategy.maxDepth = 3 algo match { case Algo.Classification => - treeStragtegy.numClasses = 2 - new BoostingStrategy(treeStragtegy, LogLoss) + treeStrategy.numClasses = 2 + new BoostingStrategy(treeStrategy, LogLoss) case Algo.Regression => - new BoostingStrategy(treeStragtegy, SquaredError) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 972959885f39..ada227c200a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will * maintain a separate RDD of node Id cache for each row. - * @param checkpointDir If the node Id cache is used, it will help to checkpoint - * the node Id cache periodically. This is the checkpoint directory - * to be used for the node Id cache. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. - * E.g. 10 means that the cache will get checkpointed every 10 updates. + * E.g. 10 means that the cache will get checkpointed every 10 updates. If + * the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. */ @Experimental class Strategy ( @@ -82,13 +81,15 @@ class Strategy ( @BeanProperty var maxMemoryInMB: Int = 256, @BeanProperty var subsamplingRate: Double = 1, @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointDir: Option[String] = None, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - def isMulticlassClassification = + def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 - def isMulticlassWithCategoricalFeatures - = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } + + def isMulticlassWithCategoricalFeatures: Boolean = { + isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + } /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] @@ -156,13 +157,16 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. */ def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 951733fada6b..f1a6ed230186 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -183,7 +183,7 @@ private[tree] object DecisionTreeMetadata extends Logging { } /** - * Version of [[buildMetadata()]] for DecisionTree. + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. */ def buildMetadata( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 83011b48b7d9..bdd0f576b048 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater( * The nodeIdsForInstances RDD needs to be updated at each iteration. * @param nodeIdsForInstances The initial values in the cache * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointDir The checkpoint directory where - * the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). */ @DeveloperApi private[tree] class NodeIdCache( var nodeIdsForInstances: RDD[Array[Int]], - val checkpointDir: Option[String], val checkpointInterval: Int) { // Keep a reference to a previous node Ids for instances. @@ -91,12 +88,6 @@ private[tree] class NodeIdCache( private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() private var rddUpdateCount = 0 - // If a checkpoint directory is given, and there's no prior checkpoint directory, - // then set the checkpoint directory with the given one. - if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { - nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) - } - /** * Update the node index values in the cache. * This updates the RDD and its lineage. @@ -184,7 +175,6 @@ private[tree] object NodeIdCache { * Initialize the node Id cache with initial node Id values. * @param data The RDD of training rows. * @param numTrees The number of trees that we want to create cache for. - * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. * @param checkpointInterval The checkpointing interval * (how often should the cache be checkpointed.). * @param initVal The initial values in the cache. @@ -193,12 +183,10 @@ private[tree] object NodeIdCache { def init( data: RDD[BaggedPoint[TreePoint]], numTrees: Int, - checkpointDir: Option[String], checkpointInterval: Int, initVal: Int = 1): NodeIdCache = { new NodeIdCache( data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointDir, checkpointInterval) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 35e361ae309c..50b292e71b06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -55,17 +55,15 @@ private[tree] object TreePoint { input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + // Construct arrays for featureArity for efficiency in the inner loop. val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) var featureIndex = 0 while (featureIndex < metadata.numFeatures) { featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - isUnordered(featureIndex) = metadata.isUnordered(featureIndex) featureIndex += 1 } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) + TreePoint.labeledPointToTreePoint(x, bins, featureArity) } } @@ -74,19 +72,17 @@ private[tree] object TreePoint { * @param bins Bins for features, of size (numFeatures, numBins). * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories * for categorical features. - * @param isUnordered Array index by feature, with value true for unordered categorical features. */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - featureArity: Array[Int], - isUnordered: Array[Boolean]): TreePoint = { + featureArity: Array[Int]): TreePoint = { val numFeatures = labeledPoint.features.size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - isUnordered(featureIndex), bins) + bins) featureIndex += 1 } new TreePoint(labeledPoint.label, arr) @@ -96,14 +92,12 @@ private[tree] object TreePoint { * Find bin for one (labeledPoint, feature). * * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, featureArity: Int, - isUnorderedFeature: Boolean, bins: Array[Array[Bin]]): Int = { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa377..5ac10f3fd32d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -71,7 +71,7 @@ object Entropy extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a..19d318203c34 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -67,7 +67,7 @@ object Gini extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index df9eafa5da16..7104a7fa4dd4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -58,7 +58,7 @@ object Variance extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ - def instance = this + def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d1bde15e6b15..2bdef73c4a8f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -37,28 +37,16 @@ object AbsoluteError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * absolute error calculation. * The gradient with respect to F(x) is: sign(F(x) - y) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + override def gradient(prediction: Double, label: Double): Double = { + if (label - prediction < 0) 1.0 else -1.0 } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean absolute error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - math.abs(err) - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val err = label - prediction + math.abs(err) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 55213e695638..778c24526de7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -39,31 +39,17 @@ object LogLoss extends Loss { * Method to calculate the loss gradients for the gradient boosting calculation for binary * classification * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - val prediction = model.predict(point.features) - - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) + override def gradient(prediction: Double, label: Double): Double = { + - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean log loss of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { case point => - val prediction = model.predict(point.features) - val margin = 2.0 * point.label * prediction - // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. - 2.0 * MLUtils.log1pExp(-margin) - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val margin = 2.0 * label * prediction + // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. + 2.0 * MLUtils.log1pExp(-margin) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 4bca9039ebe1..64ffccbce073 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -31,13 +32,11 @@ trait Loss extends Serializable { /** * Method to calculate the gradients for the gradient boosting calculation. - * @param model Model of the weak learner. - * @param point Instance of the training dataset. + * @param prediction Predicted feature + * @param label true label. * @return Loss gradient. */ - def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double + def gradient(prediction: Double, label: Double): Double /** * Method to calculate error of the base learner for the gradient boosting calculation. @@ -45,8 +44,19 @@ trait Loss extends Serializable { * purposes. * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return + * @return Measure of model error on data */ - def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map(point => computeError(model.predict(point.features), point.label)).mean() + } + /** + * Method to calculate loss when the predictions are already known. + * Note: This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. + * @param prediction Predicted label. + * @param label True label. + * @return Measure of model error on datapoint. + */ + private[mllib] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 50ecaa2f86f3..a5582d3ef332 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -37,28 +37,16 @@ object SquaredError extends Loss { * Method to calculate the gradients for the gradient boosting calculation for least * squares error calculation. * The gradient with respect to F(x) is: - 2 (y - F(x)) - * @param model Ensemble model - * @param point Instance of the training dataset + * @param prediction Predicted label. + * @param label True label. * @return Loss gradient */ - override def gradient( - model: TreeEnsembleModel, - point: LabeledPoint): Double = { - 2.0 * (model.predict(point.features) - point.label) + override def gradient(prediction: Double, label: Double): Double = { + 2.0 * (prediction - label) } - /** - * Method to calculate loss of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. - * @param model Ensemble model - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return Mean squared error of model on data - */ - override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = model.predict(y.features) - y.label - err * err - }.mean() + override private[mllib] def computeError(prediction: Double, label: Double): Double = { + val err = prediction - label + err * err } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a5760963068c..331af428533d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,22 @@ package org.apache.spark.mllib.tree.model +import scala.collection.mutable + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -31,7 +42,7 @@ import org.apache.spark.rdd.RDD * @param algo algorithm type -- classification or regression */ @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -53,7 +64,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. * @@ -99,4 +109,207 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable header + topNode.subtreeToString(2) } + override def save(sc: SparkContext, path: String): Unit = { + DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) + } + + override protected def formatVersion: String = DecisionTreeModel.formatVersion +} + +object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + + private[spark] def formatVersion: String = "1.0" + + private[tree] object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel" + + case class PredictData(predict: Double, prob: Double) { + def toPredict: Predict = new Predict(predict, prob) + } + + object PredictData { + def apply(p: Predict): PredictData = PredictData(p.predict, p.prob) + + def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1)) + } + + case class SplitData( + feature: Int, + threshold: Double, + featureType: Int, + categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + def toSplit: Split = { + new Split(feature, threshold, FeatureType(featureType), categories.toList) + } + } + + object SplitData { + def apply(s: Split): SplitData = { + SplitData(s.feature, s.threshold, s.featureType.id, s.categories) + } + + def apply(r: Row): SplitData = { + SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3)) + } + } + + /** Model data for model import/export */ + case class NodeData( + treeId: Int, + nodeId: Int, + predict: PredictData, + impurity: Double, + isLeaf: Boolean, + split: Option[SplitData], + leftNodeId: Option[Int], + rightNodeId: Option[Int], + infoGain: Option[Double]) + + object NodeData { + def apply(treeId: Int, n: Node): NodeData = { + NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, + n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), + n.stats.map(_.gain)) + } + + def apply(r: Row): NodeData = { + val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5))) + val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6)) + val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7)) + val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8)) + NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3), + r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain) + } + } + + def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + if (driverMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val nodes = model.topNode.subtreeIterator.toSeq + val dataRDD: DataFrame = sc.parallelize(nodes) + .map(NodeData.apply(0, _)) + .toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(datapath) + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[NodeData](dataRDD.schema) + val nodes = dataRDD.map(NodeData.apply) + // Build node data into a tree. + val trees = constructTrees(nodes) + assert(trees.size == 1, + "Decision tree should contain exactly one tree but got ${trees.size} trees.") + val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) + assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." + + s" Expected $numNodes nodes but found ${model.numNodes}") + model + } + + def constructTrees(nodes: RDD[NodeData]): Array[Node] = { + val trees = nodes + .groupBy(_.treeId) + .mapValues(_.toArray) + .collect() + .map { case (treeId, data) => + (treeId, constructTree(data)) + }.sortBy(_._1) + val numTrees = trees.size + val treeIndices = trees.map(_._1).toSeq + assert(treeIndices == (0 until numTrees), + s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") + trees.map(_._2) + } + + /** + * Given a list of nodes from a tree, construct the tree. + * @param data array of all node data in a tree. + */ + def constructTree(data: Array[NodeData]): Node = { + val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap + assert(dataMap.contains(1), + s"DecisionTree missing root node (id = 1).") + constructNode(1, dataMap, mutable.Map.empty) + } + + /** + * Builds a node from the node data map and adds new nodes to the input nodes map. + */ + private def constructNode( + id: Int, + dataMap: Map[Int, NodeData], + nodes: mutable.Map[Int, Node]): Node = { + if (nodes.contains(id)) { + return nodes(id) + } + val data = dataMap(id) + val node = + if (data.isLeaf) { + Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) + } else { + val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) + val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) + val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, + rightNode.impurity, leftNode.predict, rightNode.predict) + new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, + data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) + } + nodes += node.id -> node + node + } + } + + override def load(sc: SparkContext, path: String): DecisionTreeModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val algo = (metadata \ "algo").extract[String] + val numNodes = (metadata \ "numNodes").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path, algo, numNodes) + case _ => throw new Exception( + s"DecisionTreeModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 9a50ecb550c3..f209fdafd365 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -38,21 +38,32 @@ class InformationGainStats( val leftPredict: Predict, val rightPredict: Predict) extends Serializable { - override def toString = { + override def toString: String = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" .format(gain, impurity, leftImpurity, rightImpurity) } - override def equals(o: Any) = - o match { - case other: InformationGainStats => { - gain == other.gain && - impurity == other.impurity && - leftImpurity == other.leftImpurity && - rightImpurity == other.rightImpurity - } - case _ => false - } + override def equals(o: Any): Boolean = o match { + case other: InformationGainStats => + gain == other.gain && + impurity == other.impurity && + leftImpurity == other.leftImpurity && + rightImpurity == other.rightImpurity && + leftPredict == other.leftPredict && + rightPredict == other.rightPredict + + case _ => false + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + gain: java.lang.Double, + impurity: java.lang.Double, + leftImpurity: java.lang.Double, + rightImpurity: java.lang.Double, + leftPredict, + rightPredict) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 2179da8dbe03..86390a20cb5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -50,8 +50,10 @@ class Node ( var rightNode: Option[Node], var stats: Option[InformationGainStats]) extends Serializable with Logging { - override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + "split = " + split + ", stats = " + stats + override def toString: String = { + "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "impurity = " + impurity + ", split = " + split + ", stats = " + stats + } /** * build the left node and right nodes if not leaf @@ -166,9 +168,14 @@ class Node ( } } + /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */ + private[tree] def subtreeIterator: Iterator[Node] = { + Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ + rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) + } } -private[tree] object Node { +private[spark] object Node { /** * Return a node with the given node id (but nothing else set). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 004838ee5ba0..25990af7c6cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -29,7 +29,18 @@ class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { - override def toString = { + override def toString: String = { "predict = %f, prob = %f".format(predict, prob) } + + override def equals(other: Any): Boolean = { + other match { + case p: Predict => predict == p.predict && prob == p.prob + case _ => false + } + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b7a85f58544a..fb35e70a8d07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -38,9 +38,10 @@ case class Split( featureType: FeatureType, categories: List[Double]) { - override def toString = + override def toString: String = { "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 22997110de8d..8341219bfa71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -20,13 +20,24 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.mllib.tree.loss.Loss +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + /** * :: Experimental :: @@ -37,10 +48,45 @@ import org.apache.spark.rdd.RDD */ @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) - extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), - combiningStrategy = if (algo == Classification) Vote else Average) { + extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) + with Saveable { require(trees.forall(_.algo == algo)) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + RandomForestModel.SaveLoadV1_0.thisClassName) + } + + override protected def formatVersion: String = RandomForestModel.formatVersion +} + +object RandomForestModel extends Loader[RandomForestModel] { + + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + + override def load(sc: SparkContext, path: String): RandomForestModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.treeWeights.forall(_ == 1.0)) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new RandomForestModel(Algo.fromString(metadata.algo), trees) + case _ => throw new Exception(s"RandomForestModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel" + } + } /** @@ -56,9 +102,138 @@ class GradientBoostedTreesModel( override val algo: Algo, override val trees: Array[DecisionTreeModel], override val treeWeights: Array[Double]) - extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) + with Saveable { + + require(trees.length == treeWeights.length) + + override def save(sc: SparkContext, path: String): Unit = { + TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, + GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) + } + + /** + * Method to compute error or loss for every iteration of gradient boosting. + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param loss evaluation metric. + * @return an array with index i having the losses or errors for the ensemble + * containing the first i+1 trees + */ + def evaluateEachIteration( + data: RDD[LabeledPoint], + loss: Loss): Array[Double] = { + + val sc = data.sparkContext + val remappedData = algo match { + case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case _ => data + } + + val numIterations = trees.length + val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) + + evaluationArray(0) = predictionAndError.values.mean() + + val broadcastTrees = sc.broadcast(trees) + (1 until numIterations).foreach { nTree => + predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => + val currentTree = broadcastTrees.value(nTree) + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = pred + currentTree.predict(point.features) * currentTreeWeight + val newError = loss.computeError(newPred, point.label) + (newPred, newError) + } + } + evaluationArray(nTree) = predictionAndError.values.mean() + } + + broadcastTrees.unpersist() + evaluationArray + } + + override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion +} + +object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { + + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + data.map { lp => + val pred = initTreeWeight * initTree.predict(lp.features) + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeModel, + loss: Loss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + newPredError + } + + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) + assert(metadata.combiningStrategy == Sum.toString) + val trees = + TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) + new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights) + case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" + + s" with (className, format version): ($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 { + // Hard-code class name string in case it changes in the future + def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + } - require(trees.size == treeWeights.size) } /** @@ -167,12 +342,105 @@ private[tree] sealed class TreeEnsembleModel( } /** - * Get number of trees in forest. + * Get number of trees in ensemble. */ - def numTrees: Int = trees.size + def numTrees: Int = trees.length /** - * Get total number of nodes, summed over all trees in the forest. + * Get total number of nodes, summed over all trees in the ensemble. */ def totalNumNodes: Int = trees.map(_.numNodes).sum } + +private[tree] object TreeEnsembleModel extends Logging { + + object SaveLoadV1_0 { + + import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} + + def thisFormatVersion: String = "1.0" + + case class Metadata( + algo: String, + treeAlgo: String, + combiningStrategy: String, + treeWeights: Array[Double]) + + /** + * Model data for model import/export. + * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields + * of nested fields; once that is possible, we can use something like: + * case class EnsembleNodeData(treeId: Int, node: NodeData), + * where NodeData is from DecisionTreeModel. + */ + case class EnsembleNodeData(treeId: Int, node: NodeData) + + def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // SPARK-6120: We do a hacky check here so users understand why save() is failing + // when they run the ML guide example. + // TODO: Fix this issue for real. + val memThreshold = 768 + if (sc.isLocal) { + val driverMemory = sc.getConf.getOption("spark.driver.memory") + .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + if (driverMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" driver memory (${driverMemory}m)." + + s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") + } + } else { + if (sc.executorMemory <= memThreshold) { + logWarning(s"$className.save() was called, but it may fail because of too little" + + s" executor memory (${sc.executorMemory}m)." + + s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") + } + } + + // Create JSON metadata. + implicit val format = DefaultFormats + val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, + model.combiningStrategy.toString, model.treeWeights) + val metadata = compact(render( + ("class" -> className) ~ ("version" -> thisFormatVersion) ~ + ("metadata" -> Extraction.decompose(ensembleMetadata)))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => + tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) + }.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + /** + * Read metadata from the loaded JSON metadata. + */ + def readMetadata(metadata: JValue): Metadata = { + implicit val formats = DefaultFormats + (metadata \ "metadata").extract[Metadata] + } + + /** + * Load trees for an ensemble, and return them in order. + * @param path path to load the model from + * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's + * algorithm). + */ + def loadTrees( + sc: SparkContext, + path: String, + treeAlgo: String): Array[DecisionTreeModel] = { + val datapath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply) + val trees = constructTrees(nodes) + trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index 45f95482a1de..be335a1aca58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -34,11 +34,27 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ - val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => + val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") } numInvalid == 0 } + + /** + * Function to check if labels used for k class multi-label classification are + * in the range of {0, 1, ..., k - 1}. + * + * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. + */ + def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => + val numInvalid = data.filter(x => + x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() + if (numInvalid != 0) { + logError("Classification labels should be in {0 to " + (k - 1) + "}. " + + "Found " + numInvalid + " invalid labels") + } + numInvalid == 0 + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 69299c219878..c9d33787b0bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.util import scala.collection.JavaConversions._ import scala.util.Random -import org.jblas.DoubleMatrix +import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext @@ -62,7 +62,7 @@ object LinearDataGenerator { * @param nPoints Number of points in sample. * @param seed Random seed * @param eps Epsilon scaling factor. - * @return + * @return Seq of input. */ def generateLinearInput( intercept: Double, @@ -72,11 +72,10 @@ object LinearDataGenerator { eps: Double = 0.1): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) val y = x.map { xi => - new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + intercept + eps * rnd.nextGaussian() + blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } @@ -100,9 +99,9 @@ object LinearDataGenerator { eps: Double, nparts: Int = 2, intercept: Double = 0.0) : RDD[LabeledPoint] = { - org.jblas.util.Random.seed(42) + val random = new Random(42) // Random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + val w = Array.fill(nfeatures)(random.nextDouble() - 0.5) val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p => val seed = 42 + p diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index b76fbe89c368..0c5b4f9d04a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util +import java.{util => ju} + import scala.language.postfixOps import scala.util.Random -import org.jblas.DoubleMatrix - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD /** @@ -72,24 +73,25 @@ object MFDataGenerator { val sc = new SparkContext(sparkMaster, "MFDataGenerator") - val A = DoubleMatrix.randn(m, rank) - val B = DoubleMatrix.randn(rank, n) - val z = 1 / scala.math.sqrt(scala.math.sqrt(rank)) - A.mmuli(z) - B.mmuli(z) - val fullData = A.mmul(B) + val random = new ju.Random(42L) + + val A = DenseMatrix.randn(m, rank, random) + val B = DenseMatrix.randn(rank, n, random) + val z = 1 / math.sqrt(rank) + val fullData = DenseMatrix.zeros(m, n) + BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) val sampSize = scala.math.min(scala.math.round(trainSampFact * df), scala.math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n - val shuffled = rand.shuffle(1 to mn toList) + val shuffled = rand.shuffle((0 until mn).toList) val omega = shuffled.slice(0, sampSize) val ordered = omega.sortWith(_ < _).toArray val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + .map(x => (x % m, x / m, fullData.values(x))) // optionally add gaussian noise if (noise) { @@ -105,7 +107,7 @@ object MFDataGenerator { val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + .map(x => (x % m, x / m, fullData.values(x))) testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index f7cba6c6cb62..308f7f3578e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import java.util.StringTokenizer -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.{ArrayBuilder, ListBuffer} import org.apache.spark.SparkException @@ -51,7 +51,7 @@ private[mllib] object NumericParser { } private def parseArray(tokenizer: StringTokenizer): Array[Double] = { - val values = ArrayBuffer.empty[Double] + val values = ArrayBuilder.make[Double] var parsing = true var allowComma = false var token: String = null @@ -67,14 +67,14 @@ private[mllib] object NumericParser { } } else { // expecting a number - values.append(parseDouble(token)) + values += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"An array must end with ']'.") } - values.toArray + values.result() } private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { @@ -114,7 +114,7 @@ private[mllib] object NumericParser { try { java.lang.Double.parseDouble(s) } catch { - case e: Throwable => + case e: NumberFormatException => throw new SparkException(s"Cannot parse a double from: $s", e) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index 7db97e6bac68..a8e30cc9d730 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.jblas.DoubleMatrix +import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext @@ -51,8 +51,7 @@ object SVMDataGenerator { val sc = new SparkContext(sparkMaster, "SVMGenerator") val globalRnd = new Random(94720) - val trueWeights = new DoubleMatrix(1, nfeatures + 1, - Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) + val trueWeights = Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()) val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => val rnd = new Random(42 + idx) @@ -60,7 +59,7 @@ object SVMDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextDouble() * 2.0 - 1.0 } - val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1 + val yD = blas.ddot(trueWeights.length, x, 1, trueWeights, 1) + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 LabeledPoint(y, Vectors.dense(x)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala new file mode 100644 index 000000000000..30d642c754b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +trait Saveable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + def save(sc: SparkContext, path: String): Unit + + /** Current version of model save/load format. */ + protected def formatVersion: String + +} + +/** + * :: DeveloperApi :: + * + * Trait for classes which can load models and transformers from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +trait Loader[M <: Saveable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + def load(sc: SparkContext, path: String): M + +} + +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { + + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + + /** + * Check the schema of loaded model data. + * + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. + * + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. + */ + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." + + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name).sameType(field.dataType), + s"Unable to parse model data. Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } + + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { + implicit val formats = DefaultFormats + val metadata = parse(sc.textFile(metadataPath(path)).first()) + val clazz = (metadata \ "class").extract[String] + val version = (metadata \ "version").extract[String] + (clazz, version, metadata) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c26..0a8c9e595467 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -45,7 +45,7 @@ public void setUp() { jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.applySchema(points, LabeledPoint.class); + dataset = jsql.createDataFrame(points, LabeledPoint.class); } @After @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java new file mode 100644 index 000000000000..38eb58673ad5 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeGroupSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaAttributeGroupSuite { + + @Test + public void testAttributeGroup() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr(), + NominalAttribute.defaultAttr(), + BinaryAttribute.defaultAttr().withIndex(0), + NumericAttribute.defaultAttr().withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr().withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr().withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr(), + NumericAttribute.defaultAttr() + }; + AttributeGroup group = new AttributeGroup("user", attrs); + Assert.assertEquals(8, group.size()); + Assert.assertEquals("user", group.name()); + Assert.assertEquals(NumericAttribute.defaultAttr().withIndex(0), group.getAttr(0)); + Assert.assertEquals(3, group.indexOf("age")); + Assert.assertFalse(group.hasAttr("abc")); + Assert.assertEquals(group, AttributeGroup.fromStructField(group.toStructField())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java new file mode 100644 index 000000000000..b74bbed23143 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute; + +import org.junit.Test; +import org.junit.Assert; + +public class JavaAttributeSuite { + + @Test + public void testAttributeType() { + AttributeType numericType = AttributeType.Numeric(); + AttributeType nominalType = AttributeType.Nominal(); + AttributeType binaryType = AttributeType.Binary(); + Assert.assertEquals(numericType, NumericAttribute.defaultAttr().attrType()); + Assert.assertEquals(nominalType, NominalAttribute.defaultAttr().attrType()); + Assert.assertEquals(binaryType, BinaryAttribute.defaultAttr().attrType()); + } + + @Test + public void testNumericAttribute() { + NumericAttribute attr = NumericAttribute.defaultAttr() + .withName("age").withIndex(0).withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.4); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testNominalAttribute() { + NominalAttribute attr = NominalAttribute.defaultAttr() + .withName("size").withIndex(1).withValues("small", "medium", "large"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } + + @Test + public void testBinaryAttribute() { + BinaryAttribute attr = BinaryAttribute.defaultAttr() + .withName("clicked").withIndex(2).withValues("no", "yes"); + Assert.assertEquals(attr.withoutIndex(), Attribute.fromStructField(attr.toStructField())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java new file mode 100644 index 000000000000..43b8787f9dd7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + } + DecisionTreeClassificationModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + DecisionTreeClassificationModel sameModel = + DecisionTreeClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb5..3f8e59de0f05 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -18,30 +18,40 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.lang.Math; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.sql.Row; + public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; + + private transient JavaRDD datasetRDD; + private double eps = 1e-5; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); } @After @@ -51,29 +61,88 @@ public void tearDown() { } @Test - public void logisticRegression() { + public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); + assert(lr.getLabelCol().equals("label")); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); + // Check defaults + assert(model.getThreshold() == 0.5); + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + assert(model.getProbabilityCol().equals("probability")); } @Test public void logisticRegressionWithSetters() { + // Set params, train, and check as many params as we can. LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(1.0); + .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold - .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collectAsList(); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); + assert(model.getThreshold() == 0.6); + + // Modify model params, and check that the params worked. + model.setThreshold(1.0); + model.transform(dataset).registerTempTable("predAllZero"); + DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r: predAllZero.collectAsList()) { + assert(r.getDouble(0) == 0.0); + } + // Call transform with params, and check that the params worked. + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) + .registerTempTable("predNotAllZero"); + DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + boolean foundNonZero = false; + for (Row r: predNotAllZero.collectAsList()) { + if (r.getDouble(0) != 0.0) foundNonZero = true; + } + assert(foundNonZero); + + // Call fit() with new params, and check as many params as we can. + LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); + assert(model2.getThreshold() == 0.4); + assert(model2.getProbabilityCol().equals("theProb")); } + @SuppressWarnings("unchecked") @Test - public void logisticRegressionFitWithVarargs() { + public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); - lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + LogisticRegressionModel model = lr.fit(dataset); + assert(model.numClasses() == 2); + + model.transform(dataset).registerTempTable("transformed"); + DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collect()) { + Vector raw = (Vector)row.get(0); + Vector prob = (Vector)row.get(1); + assert(raw.size() == 2); + assert(prob.size() == 2); + double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); + assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); + assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + } + + DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collect()) { + double pred = row.getDouble(0); + Vector prob = (Vector)row.get(1); + double probOfPred = prob.apply((int)pred); + for (int i = 0; i < prob.size(); ++i) { + assert(probOfPred >= prob.apply(i)); + } + } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java new file mode 100644 index 000000000000..640d2ec55e4e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLogisticRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java new file mode 100644 index 000000000000..3806f650025b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaTokenizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void regexTokenizer() { + RegexTokenizer myRegExTokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(3); + + JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + + Row[] pairs = myRegExTokenizer.transform(dataset) + .select("tokens", "wantedTokens") + .collect(); + + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 000000000000..161100134c92 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java new file mode 100644 index 000000000000..a3a339004f31 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + } + DecisionTreeRegressionModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java new file mode 100644 index 000000000000..0cc36c8d56d7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaLinearRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + private transient JavaRDD datasetRDD; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + jsql = new SQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + datasetRDD = jsc.parallelize(points, 2); + dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.registerTempTable("dataset"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void linearRegressionDefaultParams() { + LinearRegression lr = new LinearRegression(); + assert(lr.getLabelCol().equals("label")); + LinearRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + predictions.collect(); + // Check defaults + assert(model.getFeaturesCol().equals("features")); + assert(model.getPredictionCol().equals("prediction")); + } + + @Test + public void linearRegressionWithSetters() { + // Set params, train, and check as many params as we can. + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0); + LinearRegressionModel model = lr.fit(dataset); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); + assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + + // Call fit() with new params, and check as many params as we can. + LinearRegressionModel model2 = + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); + assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); + assert(model2.getPredictionCol().equals("thePred")); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca..0bb6b489f275 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,14 +38,14 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 1c90522a0714..71fb7f13c39c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; public class JavaNaiveBayesSuite implements Serializable { private transient JavaSparkContext sc; @@ -102,4 +104,11 @@ public Vector call(LabeledPoint v) throws Exception { // Should be able to get the first prediction. predictions.first(); } + + @Test + public void testModelTypeSetters() { + NaiveBayes nb = new NaiveBayes() + .setModelType("Bernoulli") + .setModelType("Multinomial"); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java new file mode 100644 index 000000000000..dc10aa67c7c1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; + +import org.apache.spark.api.java.JavaRDD; +import scala.Tuple2; + +import org.junit.After; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; + + +public class JavaLDASuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLDA"); + ArrayList> tinyCorpus = new ArrayList>(); + for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), + LDASuite$.MODULE$.tinyCorpus()[i]._2())); + } + JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void localLDAModel() { + LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + + // Check: basic parameters + assertEquals(model.k(), tinyK); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), tinyTopics); + + // Check: describeTopics() with all terms + Tuple2[] fullTopicSummary = model.describeTopics(); + assertEquals(fullTopicSummary.length, tinyK); + for (int i = 0; i < fullTopicSummary.length; i++) { + assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1()); + assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5); + } + } + + @Test + public void distributedLDAModel() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345); + + DistributedLDAModel model = lda.run(corpus); + + // Check: basic parameters + LocalLDAModel localModel = model.toLocal(); + assertEquals(model.k(), k); + assertEquals(localModel.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(localModel.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), localModel.topicsMatrix()); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = localModel.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + + // Check: log probabilities + assert(model.logLikelihood() < 0.0); + assert(model.logPrior() < 0.0); + } + + private static int tinyK = LDASuite$.MODULE$.tinyK(); + private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static Tuple2[] tinyTopicDescription = + LDASuite$.MODULE$.tinyTopicDescription(); + JavaPairRDD corpus; + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java new file mode 100644 index 000000000000..bd0edf2b9ea6 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; +import static org.junit.Assert.*; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + +public class JavaFPGrowthSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runFPGrowth() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("r z h k p".split(" ")), + Lists.newArrayList("z y x w v u t s".split(" ")), + Lists.newArrayList("s x o n r".split(" ")), + Lists.newArrayList("x z y m t s q e".split(" ")), + Lists.newArrayList("z".split(" ")), + Lists.newArrayList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + assertEquals(18, freqItemsets.size()); + + for (FreqItemset itemset: freqItemsets) { + // Test return types. + List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 704d484d0b58..3349c5022423 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -71,8 +71,8 @@ public void diagonalMatrixConstruction() { Matrix sm = Matrices.diag(sv); DenseMatrix d = DenseMatrix.diag(v); DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.diag(v); - SparseMatrix ss = SparseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); assertArrayEquals(m.toArray(), sm.toArray(), 0.0); assertArrayEquals(d.toArray(), sm.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java new file mode 100644 index 000000000000..d38fc91ace3c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple3; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaIsotonicRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + private List> generateIsotonicInput(double[] labels) { + List> input = Lists.newArrayList(); + + for (int i = 1; i <= labels.length; i++) { + input.add(new Tuple3(labels[i-1], (double) i, 1d)); + } + + return input; + } + + private IsotonicRegressionModel runIsotonicRegression(double[] labels) { + JavaRDD> trainRDD = + sc.parallelize(generateIsotonicInput(labels), 2).cache(); + + return new IsotonicRegression().run(trainRDD); + } + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testIsotonicRegressionJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + Assert.assertArrayEquals( + new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + } + + @Test + public void testIsotonicRegressionPredictionsJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + List predictions = model.predict(testRDD).collect(); + + Assert.assertTrue(predictions.get(0) == 1d); + Assert.assertTrue(predictions.get(1) == 1d); + Assert.assertTrue(predictions.get(2) == 10d); + Assert.assertTrue(predictions.get(3) == 12d); + Assert.assertTrue(predictions.get(4) == 12d); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java new file mode 100644 index 000000000000..899c4ea60786 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; + +public class JavaStreamingLinearRegressionSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(1.0)), + new LabeledPoint(0.0, Vectors.dense(0.0))); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() + .setNumIterations(2) + .setInitialWeights(Vectors.dense(0.0)); + slr.trainOn(training); + JavaPairDStream prediction = slr.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae..2f175fb11794 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala new file mode 100644 index 000000000000..17ddd335deb6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import org.scalatest.FunSuite + +class AttributeGroupSuite extends FunSuite { + + test("attribute group") { + val attrs = Array( + NumericAttribute.defaultAttr, + NominalAttribute.defaultAttr, + BinaryAttribute.defaultAttr.withIndex(0), + NumericAttribute.defaultAttr.withName("age").withSparsity(0.8), + NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large"), + BinaryAttribute.defaultAttr.withName("clicked").withValues("no", "yes"), + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr) + val group = new AttributeGroup("user", attrs) + assert(group.size === 8) + assert(group.name === "user") + assert(group(0) === NumericAttribute.defaultAttr.withIndex(0)) + assert(group(2) === BinaryAttribute.defaultAttr.withIndex(2)) + assert(group.indexOf("age") === 3) + assert(group.indexOf("size") === 4) + assert(group.indexOf("clicked") === 5) + assert(!group.hasAttr("abc")) + intercept[NoSuchElementException] { + group("abc") + } + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) + } + + test("attribute group without attributes") { + val group0 = new AttributeGroup("user", 10) + assert(group0.name === "user") + assert(group0.numAttributes === Some(10)) + assert(group0.size === 10) + assert(group0.attributes.isEmpty) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + + val group1 = new AttributeGroup("item") + assert(group1.name === "item") + assert(group1.numAttributes.isEmpty) + assert(group1.attributes.isEmpty) + assert(group1.size === -1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala new file mode 100644 index 000000000000..3e1a7196e37c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.attribute + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata} + +class AttributeSuite extends FunSuite { + + test("default numeric attribute") { + val attr: NumericAttribute = NumericAttribute.defaultAttr + val metadata = Metadata.fromJson("{}") + val metadataWithType = Metadata.fromJson("""{"type":"numeric"}""") + assert(attr.attrType === AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.min.isEmpty) + assert(attr.max.isEmpty) + assert(attr.std.isEmpty) + assert(attr.sparsity.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized numeric attribute") { + val name = "age" + val index = 0 + val metadata = Metadata.fromJson("""{"name":"age","idx":0}""") + val metadataWithType = Metadata.fromJson("""{"type":"numeric","name":"age","idx":0}""") + val attr: NumericAttribute = NumericAttribute.defaultAttr + .withName(name) + .withIndex(index) + assert(attr.attrType == AttributeType.Numeric) + assert(attr.isNumeric) + assert(!attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === Attribute.fromMetadata(metadataWithType)) + val field = attr.toStructField() + assert(field.dataType === DoubleType) + assert(!field.nullable) + assert(attr.withoutIndex === Attribute.fromStructField(field)) + val existingMetadata = new MetadataBuilder() + .putString("name", "test") + .build() + assert(attr.toStructField(existingMetadata).metadata.getString("name") === "test") + + val attr2 = + attr.withoutName.withoutIndex.withMin(0.0).withMax(1.0).withStd(0.5).withSparsity(0.3) + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.min === Some(0.0)) + assert(attr2.max === Some(1.0)) + assert(attr2.std === Some(0.5)) + assert(attr2.sparsity === Some(0.3)) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + } + + test("bad numeric attributes") { + val attr = NumericAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withStd(-0.1)) + intercept[IllegalArgumentException](attr.withSparsity(-0.5)) + intercept[IllegalArgumentException](attr.withSparsity(1.5)) + } + + test("default nominal attribute") { + val attr: NominalAttribute = NominalAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"nominal"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.numValues.isEmpty) + assert(attr.isOrdinal.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized nominal attribute") { + val name = "size" + val index = 1 + val values = Array("small", "medium", "large") + val metadata = Metadata.fromJson( + """{"type":"nominal","name":"size","idx":1,"vals":["small","medium","large"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"size","idx":1,"vals":["small","medium","large"]}""") + val attr: NominalAttribute = NominalAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values) + assert(attr.attrType === AttributeType.Nominal) + assert(!attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values === Some(values)) + assert(attr.indexOf("medium") === 1) + assert(attr.getValue(1) === "medium") + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + + val attr2 = attr.withoutName.withoutIndex.withValues(attr.values.get :+ "x-large") + assert(attr2.name.isEmpty) + assert(attr2.index.isEmpty) + assert(attr2.values.get === Array("small", "medium", "large", "x-large")) + assert(attr2.indexOf("x-large") === 3) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false))) + } + + test("bad nominal attributes") { + val attr = NominalAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + intercept[IllegalArgumentException](attr.withNumValues(-1)) + } + + test("default binary attribute") { + val attr = BinaryAttribute.defaultAttr + val metadata = Metadata.fromJson("""{"type":"binary"}""") + val metadataWithoutType = Metadata.fromJson("{}") + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name.isEmpty) + assert(attr.index.isEmpty) + assert(attr.values.isEmpty) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + intercept[NoSuchElementException] { + attr.toStructField() + } + } + + test("customized binary attribute") { + val name = "clicked" + val index = 2 + val values = Array("no", "yes") + val metadata = Metadata.fromJson( + """{"type":"binary","name":"clicked","idx":2,"vals":["no","yes"]}""") + val metadataWithoutType = Metadata.fromJson( + """{"name":"clicked","idx":2,"vals":["no","yes"]}""") + val attr = BinaryAttribute.defaultAttr + .withName(name) + .withIndex(index) + .withValues(values(0), values(1)) + assert(attr.attrType === AttributeType.Binary) + assert(attr.isNumeric) + assert(attr.isNominal) + assert(attr.name === Some(name)) + assert(attr.index === Some(index)) + assert(attr.values.get === values) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) + assert(attr === Attribute.fromMetadata(metadata)) + assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) + assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) + } + + test("bad binary attributes") { + val attr = BinaryAttribute.defaultAttr + intercept[IllegalArgumentException](attr.withName("")) + intercept[IllegalArgumentException](attr.withIndex(-1)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala new file mode 100644 index 000000000000..af88595df524 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeClassifierSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + DecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) + val newModel = DecisionTreeClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = DecisionTreeClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private[ml] object DecisionTreeClassifierSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1..35d8c2e16c6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -20,50 +20,117 @@ package org.apache.spark.ml.classification import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ + private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( - sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) } - test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ + test("logistic regression: default params") { val lr = new LogisticRegression + assert(lr.getLabelCol == "label") + assert(lr.getFeaturesCol == "features") + assert(lr.getPredictionCol == "prediction") + assert(lr.getRawPredictionCol == "rawPrediction") + assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "probability", "prediction", "rawPrediction") .collect() + assert(model.getThreshold === 0.5) + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + assert(model.getRawPredictionCol == "rawPrediction") + assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ + // Set params, train, and check as many params as we can. val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) + .setThreshold(0.6) + .setProbabilityCol("myProbability") val model = lr.fit(dataset) - model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + assert(model.fittingParamMap.get(lr.maxIter) === Some(10)) + assert(model.fittingParamMap.get(lr.regParam) === Some(1.0)) + assert(model.fittingParamMap.get(lr.threshold) === Some(0.6)) + assert(model.getThreshold === 0.6) + + // Modify model params, and check that the params worked. + model.setThreshold(1.0) + val predAllZero = model.transform(dataset) + .select("prediction", "myProbability") .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") + // Call transform with params, and check that the params worked. + val predNotAllZero = + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") + .select("prediction", "myProb") + .collect() + .map { case Row(pred: Double, prob: Vector) => pred } + assert(predNotAllZero.exists(_ !== 0.0)) + + // Call fit() with new params, and check as many params as we can. + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, + lr.probabilityCol -> "theProb") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.fittingParamMap.get(lr.threshold).get === 0.4) + assert(model2.getThreshold === 0.4) + assert(model2.getProbabilityCol == "theProb") } - test("logistic regression fit and transform with varargs") { + test("logistic regression: Predictor, Classifier methods") { val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression - val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) - model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) - .collect() + + val model = lr.fit(dataset) + assert(model.numClasses === 2) + + val threshold = model.getThreshold + val results = model.transform(dataset) + + // Compare rawPrediction with probability + results.select("rawPrediction", "probability").collect().map { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) + assert(prob(1) ~== probFromRaw1 relTol eps) + assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) + } + + // Compare prediction with probability + results.select("prediction", "probability").collect().map { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala new file mode 100644 index 000000000000..eaee3443c1f2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class IDFSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { + dataSet.map { + case data: DenseVector => + val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y } + Vectors.dense(res) + case data: SparseVector => + val res = data.indices.zip(data.values).map { case (id, value) => + (id, value * model(id)) + } + Vectors.sparse(data.size, res) + } + } + + test("compute IDF with default parameter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((numOfData + 1.0) / (x + 1.0)) + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } + + test("compute IDF with setter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .setMinDocFreq(1) + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala new file mode 100644 index 000000000000..9d09f24709e2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var data: Array[Vector] = _ + @transient var dataFrame: DataFrame = _ + @transient var normalizer: Normalizer = _ + @transient var l1Normalized: Array[Vector] = _ + @transient var l2Normalized: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq((1, 0.91), (2, 3.2))), + Vectors.sparse(3, Seq((0, 5.7), (1, 0.72), (2, 2.7))), + Vectors.sparse(3, Seq()) + ) + l1Normalized = Array( + Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.12765957, -0.23404255, -0.63829787), + Vectors.sparse(3, Seq((1, 0.22141119), (2, 0.7785888))), + Vectors.dense(0.625, 0.07894737, 0.29605263), + Vectors.sparse(3, Seq()) + ) + l2Normalized = Array( + Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.184549876, -0.3383414, -0.922749378), + Vectors.sparse(3, Seq((1, 0.27352993), (2, 0.96186349))), + Vectors.dense(0.897906166, 0.113419726, 0.42532397), + Vectors.sparse(3, Seq()) + ) + + val sqlContext = new SQLContext(sc) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normalized_features") + } + + def collectResult(result: DataFrame): Array[Vector] = { + result.select("normalized_features").collect().map { + case Row(features: Vector) => features + } + } + + def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after normalization.") + } + + def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { + assert((lhs, rhs).zipped.forall { (vector1, vector2) => + vector1 ~== vector2 absTol 1E-5 + }, "The vector value is not correct after normalization.") + } + + test("Normalization with default parameter") { + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l2Normalized) + } + + test("Normalization with setter") { + normalizer.setP(1) + + val result = collectResult(normalizer.transform(dataFrame)) + + assertTypeOfVector(data, result) + + assertValues(result, l1Normalized) + } +} + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala new file mode 100644 index 000000000000..c1d64fba0aa8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} +import org.scalatest.exceptions.TestFailedException + +class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("Polynomial expansion with default parameter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(9, Array.empty, Array.empty)) + + val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + + test("Polynomial expansion with setter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(19, Array.empty, Array.empty)) + + val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 000000000000..00b5d094d82f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SQLContext + +class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala new file mode 100644 index 000000000000..d186ead8f542 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) + +class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("RegexTokenizer") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer, dataset0) + + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Array("punct")) + )) + + tokenizer.setMinTokenLength(3) + testRegexTokenizer(tokenizer, dataset1) + + tokenizer + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(0) + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) + )) + testRegexTokenizer(tokenizer, dataset2) + } +} + +object RegexTokenizerSuite extends FunSuite { + + def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("tokens", "wantedTokens") + .collect() + .foreach { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala new file mode 100644 index 000000000000..57d0278e0363 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("assemble") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + val dv = Vectors.dense(2.0, 0.0) + assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) + assert(assemble(0.0, dv, 1.0, sv) === + Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) + for (v <- Seq(1, "a", null)) { + intercept[SparkException](assemble(v)) + intercept[SparkException](assemble(1.0, v)) + } + } + + test("VectorAssembler") { + val df = sqlContext.createDataFrame(Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + assembler.transform(df).select("features").collect().foreach { + case Row(v: Vector) => + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 000000000000..1b261b264385 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + // Chose correct categorical features + assert(categoryMaps.keys.toSet === categoricalFeatures) + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata())) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. + TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala new file mode 100644 index 000000000000..2e57d4ce37f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl + +import scala.collection.JavaConverters._ + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +private[ml] object TreeTests extends FunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = new SQLContext(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + // TODO: Reinstate after adding ensembles + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + /* + def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + try { + a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.getTreeWeights === b.getTreeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + */ +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce298761237..88ea679eeaad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,25 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) @@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { @@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a7..641b64b42a5e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,20 +17,21 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) - override def validate(paramMap: ParamMap) = { - val m = this.paramMap ++ paramMap + setDefault(maxIter -> 10) + + override def validate(paramMap: ParamMap): Unit = { + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index cdd4db1b5b7d..fc7349330cf8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,31 +17,42 @@ package org.apache.spark.ml.recommendation +import java.io.File import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.Utils class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { private var sqlContext: SQLContext = _ + private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() + tempDir = Utils.createTempDir() + sc.setCheckpointDir(tempDir.getAbsolutePath) sqlContext = new SQLContext(sc) } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("LocalIndexEncoder") { val random = new Random for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) { @@ -58,39 +69,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -98,41 +112,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -144,18 +133,17 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } test("RatingBlockBuilder") { - val emptyBuilder = new RatingBlockBuilder() + val emptyBuilder = new RatingBlockBuilder[Int]() assert(emptyBuilder.size === 0) val emptyBlock = emptyBuilder.build() assert(emptyBlock.srcIds.isEmpty) @@ -179,12 +167,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("UncompressedInBlock") { val encoder = new LocalIndexEncoder(10) - val uncompressed = new UncompressedInBlockBuilder(encoder) + val uncompressed = new UncompressedInBlockBuilder[Int](encoder) .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) .build() - assert(uncompressed.size === 5) - val records = Seq.tabulate(uncompressed.size) { i => + assert(uncompressed.length === 5) + val records = Seq.tabulate(uncompressed.length) { i => val dstEncodedIndex = uncompressed.dstEncodedIndices(i) val dstBlockId = encoder.blockId(dstEncodedIndex) val dstLocalIndex = encoder.localIndex(dstEncodedIndex) @@ -228,15 +216,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItems: Int, rank: Int, noiseStd: Double = 0.0, - seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { val trainingFraction = 0.6 val testFraction = 0.3 val totalFraction = trainingFraction + testFraction val random = new Random(seed) val userFactors = genFactors(numUsers, rank, random) val itemFactors = genFactors(numItems, rank, random) - val training = ArrayBuffer.empty[Rating] - val test = ArrayBuffer.empty[Rating] + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { val x = random.nextDouble() if (x < totalFraction) { @@ -268,7 +256,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItems: Int, rank: Int, noiseStd: Double = 0.0, - seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { // The assumption of the implicit feedback model is that unobserved ratings are more likely to // be negatives. val positiveFraction = 0.8 @@ -279,8 +267,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val random = new Random(seed) val userFactors = genFactors(numUsers, rank, random) val itemFactors = genFactors(numItems, rank, random) - val training = ArrayBuffer.empty[Rating] - val test = ArrayBuffer.empty[Rating] + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) val threshold = if (rating > 0) positiveFraction else negativeFraction @@ -340,8 +328,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { * @param targetRMSE target test RMSE */ def testALS( - training: RDD[Rating], - test: RDD[Rating], + training: RDD[Rating[Int]], + test: RDD[Rating[Int]], rank: Int, maxIter: Int, regParam: Double, @@ -350,7 +338,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + import sqlContext.implicits._ val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -358,9 +346,9 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) val alpha = als.getAlpha - val model = als.fit(training) - val predictions = model.transform(test) - .select('rating, 'prediction) + val model = als.fit(training.toDF()) + val predictions = model.transform(test.toDF()) + .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) } @@ -414,7 +402,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) { - testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03, + testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03, numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks) } } @@ -432,4 +420,64 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, targetRMSE = 0.3) } + + test("using generic ID types") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + + val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + assert(longUserFactors.first()._1.getClass === classOf[Long]) + + val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + assert(strUserFactors.first()._1.getClass === classOf[String]) + } + + test("nonnegative constraint") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { + factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) + } + assert(isNonnegative(userFactors)) + assert(isNonnegative(itemFactors)) + // TODO: Validate the solution. + } + + test("als partitioner is a projection") { + for (p <- Seq(1, 10, 100, 1000)) { + val part = new ALSPartitioner(p) + var k = 0 + while (k < p) { + assert(k === part.getPartition(k)) + assert(k === part.getPartition(k.toLong)) + k += 1 + } + } + } + + test("partitioner in returned factors") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train( + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4) + for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { + assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.") + val part = userFactors.partitioner.get + userFactors.mapPartitionsWithIndex { (idx, items) => + items.foreach { case (id, _) => + if (part.getPartition(id) != idx) { + throw new SparkException(s"$tpe with ID $id should not be in partition $idx.") + } + } + Iterator.empty + }.count() + } + } + + test("als with large number of iterations") { + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2) + ALS.train( + ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala new file mode 100644 index 000000000000..0b40fe33fae9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeRegressorSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: test("model save/load") +} + +private[ml] object DecisionTreeRegressorSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala new file mode 100644 index 000000000000..bbb44c3e2dfc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} + +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + } + + test("linear regression: default params") { + val lr = new LinearRegression + assert(lr.getLabelCol == "label") + val model = lr.fit(dataset) + model.transform(dataset) + .select("label", "prediction") + .collect() + // Check defaults + assert(model.getFeaturesCol == "features") + assert(model.getPredictionCol == "prediction") + } + + test("linear regression with setters") { + // Set params, train, and check as many as we can. + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + assert(model.fittingParamMap.get(lr.maxIter).get === 10) + assert(model.fittingParamMap.get(lr.regParam).get === 1.0) + + // Call fit() with new params, and check as many as we can. + val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") + assert(model2.fittingParamMap.get(lr.maxIter).get === 5) + assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) + assert(model2.getPredictionCol == "thePred") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b..761ea821ef7c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() val sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 000000000000..c44cb61b3417 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and check to ensure the + * extra metadata have not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 94b0e00f3726..a26c52852c4d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,16 +17,19 @@ package org.apache.spark.mllib.classification -import scala.util.Random import scala.collection.JavaConversions._ +import scala.util.Random +import scala.util.control.Breaks._ import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + object LogisticRegressionSuite { @@ -55,8 +58,116 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) testData } + + /** + * Generates `k` classes multinomial synthetic logistic input in `n` dimensional space given the + * model weights and mean/variance of the features. The synthetic data will be drawn from + * the probability distribution constructed by weights using the following formula. + * + * P(y = 0 | x) = 1 / norm + * P(y = 1 | x) = exp(x * w_1) / norm + * P(y = 2 | x) = exp(x * w_2) / norm + * ... + * P(y = k-1 | x) = exp(x * w_{k-1}) / norm + * where norm = 1 + exp(x * w_1) + exp(x * w_2) + ... + exp(x * w_{k-1}) + * + * @param weights matrix is flatten into a vector; as a result, the dimension of weights vector + * will be (k - 1) * (n + 1) if `addIntercept == true`, and + * if `addIntercept != true`, the dimension will be (k - 1) * n. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param addIntercept whether to add intercept. + * @param nPoints the number of instance of generated data. + * @param seed the seed for random generator. For consistent testing result, it will be fixed. + */ + def generateMultinomialLogisticInput( + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + addIntercept: Boolean, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + + val xDim = xMean.size + val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim + val nClasses = weights.size / xWithInterceptsDim + 1 + + val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian()))) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) + i += 1 + } + }) + + val y = (0 until nPoints).map { idx => + val xArray = x(idx).toArray + val margins = Array.ofDim[Double](nClasses) + val probs = Array.ofDim[Double](nClasses) + + for (i <- 0 until nClasses - 1) { + for (j <- 0 until xDim) margins(i + 1) += weights(i * xWithInterceptsDim + j) * xArray(j) + if (addIntercept) margins(i + 1) += weights((i + 1) * xWithInterceptsDim - 1) + } + // Preventing the overflow when we compute the probability + val maxMargin = margins.max + if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + + // Computing the probabilities for each class from the margins. + val norm = { + var temp = 0.0 + for (i <- 0 until nClasses) { + probs(i) = math.exp(margins(i)) + temp += probs(i) + } + temp + } + for (i <-0 until nClasses) probs(i) /= norm + + // Compute the cumulative probability so we can generate a random number and assign a label. + for (i <- 1 until nClasses) probs(i) += probs(i - 1) + val p = rnd.nextDouble() + var y = 0 + breakable { + for (i <- 0 until nClasses) { + if (p < probs(i)) { + y = i + break + } + } + } + y + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) + testData + } + + /** Binary labels, 3 features */ + private val binaryModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2) + + /** 3 classes, 2 features */ + private val multiclassModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) + + private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = { + assert(a.weights == b.weights) + assert(a.intercept == b.intercept) + assert(a.numClasses == b.numClasses) + assert(a.numFeatures == b.numFeatures) + assert(a.getThreshold == b.getThreshold) + } } + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], @@ -261,8 +372,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M testRDD2.cache() testRDD3.cache() + val numIteration = 10 + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + lrA.optimizer.setNumIterations(numIteration) val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + lrB.optimizer.setNumIterations(numIteration) val modelA1 = lrA.run(testRDD1, initialWeights) val modelA2 = lrA.run(testRDD2, initialWeights) @@ -285,6 +400,144 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) } + test("multinomial logistic regression with LBFGS") { + val nPoints = 10000 + + /** + * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2. + * As a result, we are actually drawing samples from probability distribution of built model. + */ + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(3) + lr.optimizer.setConvergenceTol(1E-15).setNumIterations(200) + + val model = lr.run(testRDD) + + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + + /** + * The following is the instruction to reproduce the model using R's glmnet package. + * + * First of all, using the following scala code to save the data into `path`. + * + * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + + * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + * + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE) + * label = factor(data$V1) + * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) + * + * The model weights of mutinomial logstic regression in R have `K` set of linear predictors + * for `K` classes classification problem; however, only `K-1` set is required if the first + * outcome is chosen as a "pivot", and the other `K-1` outcomes are separately regressed against + * the pivot outcome. This can be done by subtracting the first weights from those `K-1` set + * weights. The mathematical discussion and proof can be found here: + * http://en.wikipedia.org/wiki/Multinomial_logistic_regression + * + * weights1 = weights$`1` - weights$`0` + * weights2 = weights$`2` - weights$`0` + * + * > weights1 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 2.6228269 + * data.V2 -0.5837166 + * data.V3 0.9285260 + * data.V4 -0.3783612 + * data.V5 -0.8123411 + * > weights2 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 4.11197445 + * data.V2 -0.16918650 + * data.V3 -0.81104784 + * data.V4 -0.06463799 + * data.V5 -0.29198337 + */ + + val weightsR = Vectors.dense(Array( + -0.5837166, 0.9285260, -0.3783612, -0.8123411, 2.6228269, + -0.1691865, -0.811048, -0.0646380, -0.2919834, 4.1119745)) + + assert(model.weights ~== weightsR relTol 0.05) + + val validationData = LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // The validation accuracy is not good since this model (even the original weights) doesn't have + // very steep curve in logistic function so that when we draw samples from distribution, it's + // very easy to assign to another labels. However, this prediction result is consistent to R. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47) + + } + + test("model save/load: binary classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load: multiclass classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.multiclassModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index e68fe89d6cce..ea89b17b7c08 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,12 +19,17 @@ package org.apache.spark.mllib.classification import scala.util.Random +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} + import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils + object NaiveBayesSuite { @@ -39,25 +44,48 @@ object NaiveBayesSuite { // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( - pi: Array[Double], // 1XC - theta: Array[Array[Double]], // CXD - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = "Multinomial", + sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) - val _pi = pi.map(math.pow(math.E, _)) val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) - val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 + val xi = modelType match { + case "Bernoulli" => Array.tabulate[Double] (D) { j => + if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 + } + case "Multinomial" => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0 until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") } LabeledPoint(y, Vectors.dense(xi)) } } + + /** Bernoulli NaiveBayes with binary labels, 3 features */ + private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Bernoulli") + + /** Multinomial NaiveBayes with binary labels, 3 features */ + private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), + "Multinomial") } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -71,23 +99,79 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } - test("Naive Bayes") { - val nPoints = 10000 + def validateModelFit( + piData: Array[Double], + thetaData: Array[Array[Double]], + model: NaiveBayesModel): Unit = { + def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { + (d1 - d2).abs <= precision + } + val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) + for (i <- modelIndex) { + assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) + } + for (i <- modelIndex) { + for (j <- 0 until thetaData(i._2).length) { + assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + } + } + } + + test("get, set params") { + val nb = new NaiveBayes() + nb.setLambda(2.0) + assert(nb.getLambda === 2.0) + nb.setLambda(3.0) + assert(nb.getLambda === 3.0) + } + + test("Naive Bayes Multinomial") { + val nPoints = 1000 + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 42, "Multinomial") + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + validateModelFit(pi, theta, model) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 17, "Multinomial") + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("Naive Bayes Bernoulli") { + val nPoints = 10000 val pi = Array(0.5, 0.3, 0.2).map(math.log) val theta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + val testData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 45, "Bernoulli") val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD) + val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + validateModelFit(pi, theta, model) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput( + pi, theta, nPoints, 20, "Bernoulli") val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -123,6 +207,46 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { NaiveBayes.train(sc.makeRDD(nan, 2)) } } + + test("model save/load: 2.0 to 2.0") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Seq(NaiveBayesSuite.binaryBernoulliModel, NaiveBayesSuite.binaryMultinomialModel).map { + model => + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === sameModel.modelType) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("model save/load: 1.0 to 2.0") { + val model = NaiveBayesSuite.binaryMultinomialModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model as version 1.0, load it back, and compare. + try { + val data = NaiveBayesModel.SaveLoadV1_0.Data(model.labels, model.pi, model.theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + assert(model.modelType === "Multinomial") + } finally { + Utils.deleteRecursively(tempDir) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { @@ -136,8 +260,8 @@ class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble()))) } } - // If we serialize data directly in the task closure, the size of the serialized task would be - // greater than 1MB and hence Spark would throw an error. + // If we serialize data directly in the task closure, the size of the serialized task + // would be greater than 1MB and hence Spark would throw an error. val model = NaiveBayes.train(examples) val predictions = model.predict(examples.map(_.features)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index a2de7fbd4138..6de098b383ba 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils object SVMSuite { @@ -56,6 +57,9 @@ object SVMSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + /** Binary labels, 3 features */ + private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + } class SVMSuite extends FunSuite with MLlibTestSparkContext { @@ -191,6 +195,38 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { // Turning off data validation should not throw an exception new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } + + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = SVMSuite.binaryModel + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } + + // Save model with threshold. + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala new file mode 100644 index 000000000000..5683b55e8500 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { + + // use longer wait time to ensure job completion + override def maxWaitTimeMillis: Int = 30000 + + // Test if we can accurately learn B for Y = logistic(BX) on streaming data + test("parameter accuracy") { + + val nPoints = 100 + val B = 1.5 + + // create model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check accuracy of final parameter estimates + assert(model.latestModel().weights(0) ~== B relTol 0.1) + + } + + // Test that parameter estimates improve when learning Y = logistic(BX) on streaming data + test("parameter convergence") { + + val B = 1.5 + val nPoints = 100 + + // create model + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data + val numBatches = 20 + val input = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) + + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // compute change in error + val deltas = history.drop(1).zip(history.dropRight(1)) + // check error stability (it always either shrinks, or increases with small tol) + assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) + // check that error shrunk on at least 2 batches + assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) + } + + // Test predictions on a stream + test("predictions") { + + val B = 1.5 + val nPoints = 100 + + // create model initialized with true weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(1.5)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, B, nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // check that at least 60% of predictions are correct on all batches + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + + assert(errors.forall(x => x <= 0.4)) + } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala similarity index 51% rename from mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 198997b5bb2b..f356ffa3e3a2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -23,15 +23,16 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils -class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), Vectors.dense(5.0, 10.0), Vectors.dense(4.0, 11.0) )) - + // expectations val Ew = 1.0 val Emu = Vectors.dense(5.0, 10.0) @@ -39,22 +40,17 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val seeds = Array(314589, 29032897, 50181, 494821, 4660) seeds.foreach { seed => - val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) assert(gmm.weights(0) ~== Ew absTol 1E-5) assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) } + } test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - + val data = sc.parallelize(GaussianTestData.data) + // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( Array(0.5, 0.5), @@ -63,16 +59,16 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) ) ) - + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - val gmm = new GaussianMixtureEM() + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) .run(data) - + assert(gmm.weights(0) ~== Ew(0) absTol 1E-3) assert(gmm.weights(1) ~== Ew(1) absTol 1E-3) assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3) @@ -80,4 +76,88 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("single cluster with sparse data") { + val data = sc.parallelize(Array( + Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), + Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)), + Vectors.sparse(3, Array(1), Array(6.0)) + )) + + val Ew = 1.0 + val Emu = Vectors.dense(2.0, 2.0, 2.0) + val Esigma = Matrices.dense(3, 3, + Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0) + ) + + val seeds = Array(42, 1994, 27, 11, 0) + seeds.foreach { seed => + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) + assert(gmm.weights(0) ~== Ew absTol 1E-5) + assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) + assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) + } + } + + test("two clusters with sparse data") { + val data = sc.parallelize(GaussianTestData.data) + val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array( + new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))), + new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0))) + ) + ) + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val sparseGMM = new GaussianMixture() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) + assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3) + assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) + assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) + } + + test("model save / load") { + val data = sc.parallelize(GaussianTestData.data) + + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + gmm.save(sc, path) + + // TODO: GaussianMixtureModel should implement equals/hashcode directly. + val sameModel = GaussianMixtureModel.load(sc, path) + assert(sameModel.k === gmm.k) + (0 until sameModel.k).foreach { i => + assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu) + assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + object GaussianTestData { + + val data = Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + ) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index caee5917000a..0f2b26d462ad 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,9 +21,10 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class KMeansSuite extends FunSuite with MLlibTestSparkContext { @@ -198,9 +199,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { test("k-means|| initialization") { case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { - @Override def compare(that: VectorWithCompare): Int = { - if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > - that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + override def compare(that: VectorWithCompare): Int = { + if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) { + -1 + } else { + 1 + } } } @@ -257,6 +262,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(predicts(0) != predicts(3)) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(true, false).foreach { case selector => + val model = KMeansSuite.createModel(10, 3, selector) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = KMeansModel.load(sc, path) + KMeansSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +object KMeansSuite extends FunSuite { + def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { + val singlePoint = isSparse match { + case true => + Vectors.sparse(dim, Array.empty[Int], Array.empty[Double]) + case _ => + Vectors.dense(Array.fill[Double](dim)(0.0)) + } + new KMeansModel(Array.fill[Vector](k)(singlePoint)) + } + + def checkEqual(a: KMeansModel, b: KMeansModel): Unit = { + assert(a.k === b.k) + a.clusterCenters.zip(b.clusterCenters).foreach { + case (ca: SparseVector, cb: SparseVector) => + assert(ca === cb) + case (ca: DenseVector, cb: DenseVector) => + assert(ca === cb) + case _ => + throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + } + } } class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 000000000000..cc747dabb996 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LocalLDAModel") { + val model = new LocalLDAModel(tinyTopics) + + // Check: basic parameters + assert(model.k === tinyK) + assert(model.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === tinyTopics) + + // Check: describeTopics() with all terms + val fullTopicSummary = model.describeTopics() + assert(fullTopicSummary.size === tinyK) + fullTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms) + assert(algTermWeights === termWeights) + } + + // Check: describeTopics() with some terms + val smallNumTerms = 3 + val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms) + smallTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms.slice(0, smallNumTerms)) + assert(algTermWeights === termWeights.slice(0, smallNumTerms)) + } + } + + test("running and DistributedLDAModel") { + val k = 3 + val topicSmoothing = 1.2 + val termSmoothing = 1.2 + + // Train a model + val lda = new LDA() + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + + val model: DistributedLDAModel = lda.run(corpus) + + // Check: basic parameters + val localModel = model.toLocal + assert(model.k === k) + assert(localModel.k === k) + assert(model.vocabSize === tinyVocabSize) + assert(localModel.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === localModel.topicsMatrix) + + // Check: topic summaries + // The odd decimal formatting and sorting is a hack to do a robust comparison. + val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => + assert(t1 === t2) + } + + // Check: per-doc topic distributions + val topicDistributions = model.topicDistributions.collect() + // Ensure all documents are covered. + assert(topicDistributions.size === tinyCorpus.size) + assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) + // Ensure we have proper distributions + topicDistributions.foreach { case (docId, topicDistribution) => + assert(topicDistribution.size === tinyK) + assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) + } + + // Check: log probabilities + assert(model.logLikelihood < 0.0) + assert(model.logPrior < 0.0) + } + + test("vertex indexing") { + // Check vertex ID indexing and conversions. + val docIds = Array(0, 1, 2) + val docVertexIds = docIds + val termIds = Array(0, 1, 2) + val termVertexIds = Array(-1, -2, -3) + assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0)))) + assert(termIds.map(LDA.term2index) === termVertexIds) + assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) + assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) + } + + test("setter alias") { + val lda = new LDA().setAlpha(2.0).setBeta(3.0) + assert(lda.getAlpha === 2.0) + assert(lda.getDocConcentration === 2.0) + assert(lda.getBeta === 3.0) + assert(lda.getTopicConcentration === 3.0) + } +} + +private[clustering] object LDASuite { + + def tinyK: Int = 3 + def tinyVocabSize: Int = 5 + def tinyTopicsAsArray: Array[Array[Double]] = Array( + Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0 + Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1 + Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2 + ) + def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK, + values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _)) + def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic => + val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + + def tinyCorpus: Array[(Long, Vector)] = Array( + Vectors.dense(1, 3, 0, 2, 8), + Vectors.dense(0, 2, 1, 0, 4), + Vectors.dense(2, 3, 12, 3, 1), + Vectors.dense(0, 3, 1, 9, 8), + Vectors.dense(1, 1, 4, 2, 6) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 000000000000..6d6fe6fe46ba --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable +import scala.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + test("power iteration clustering") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + val model = new PowerIterationClustering() + .setK(2) + .run(sc.parallelize(similarities, 2)) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + + test("normalize and powerIter") { + /* + Test normalize() with the following graph: + + 0 - 3 + | \ | + 1 - 2 + + The affinity matrix (A) is + + 0 1 1 1 + 1 0 1 0 + 1 1 0 1 + 1 0 1 0 + + D is diag(3, 2, 3, 2) and hence W is + + 0 1/3 1/3 1/3 + 1/2 0 1/2 0 + 1/3 1/3 0 1/3 + 1/2 0 1/2 0 + */ + val similarities = Seq[(Long, Long, Double)]( + (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + val expected = Array( + Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), + Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + val w = normalize(sc.parallelize(similarities, 2)) + w.edges.collect().foreach { case Edge(i, j, x) => + assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) + } + val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2) + val w0 = Graph(v0, w.edges) + val v1 = powerIter(w0, maxIterations = 1).collect() + val u = Array(0.3, 0.2, 0.7/3.0, 0.2) + val norm = u.sum + val u1 = u.map(x => x / norm) + v1.foreach { case (i, x) => + assert(x ~== u1(i.toInt) absTol 1e-14) + } + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends FunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 850c9fce507c..f90025d535e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingKMeansSuite extends FunSuite with TestSuiteBase { - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for single center and equivalence to grand average") { // set parameters @@ -59,7 +59,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala new file mode 100644 index 000000000000..747f5914598e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { + + /* + * Contingency tables + * feature0 = {8.0, 0.0} + * class 0 1 2 + * 8.0||1|0|1| + * 0.0||0|2|0| + * + * feature1 = {7.0, 9.0} + * class 0 1 2 + * 7.0||1|0|0| + * 9.0||0|2|1| + * + * feature2 = {0.0, 6.0, 8.0, 5.0} + * class 0 1 2 + * 0.0||1|0|0| + * 6.0||0|1|0| + * 8.0||0|1|0| + * 5.0||0|0|1| + * + * Use chi-squared calculator from Internet + */ + + test("ChiSqSelector transform test (sparse & dense vector)") { + val labeledDiscreteData = sc.parallelize( + Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) + val preFilteredData = + Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(6.0))), + LabeledPoint(1.0, Vectors.dense(Array(8.0))), + LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + val model = new ChiSqSelector(1).fit(labeledDiscreteData) + val filteredData = labeledDiscreteData.map { lp => + LabeledPoint(lp.label, model.transform(lp.features)) + }.collect().toSet + assert(filteredData == preFilteredData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 4c93c0ca4f86..7f94564b2a3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -22,29 +22,114 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val constantData = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val sparseData = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val denseData = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) } + test("Standardization with dense input when means and stds are provided") { + + val dataRDD = sc.parallelize(denseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = denseData.map(equivalentModel1.transform) + val data2 = denseData.map(equivalentModel2.transform) + val data3 = denseData.map(equivalentModel3.transform) + + val data1RDD = equivalentModel1.transform(dataRDD) + val data2RDD = equivalentModel2.transform(dataRDD) + val data3RDD = equivalentModel3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((denseData, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + test("Standardization with dense input") { - val data = Array( - Vectors.dense(-2.0, 2.3, 0), - Vectors.dense(0.0, -1.0, -3.0), - Vectors.dense(0.0, -5.1, 0.0), - Vectors.dense(3.8, 0.0, 1.9), - Vectors.dense(1.7, -0.6, 0.0), - Vectors.dense(0.0, 1.9, 0.0) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(denseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -54,9 +139,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = denseData.map(model1.transform) + val data2 = denseData.map(model2.transform) + val data3 = denseData.map(model3.transform) val data1RDD = model1.transform(dataRDD) val data2RDD = model2.transform(dataRDD) @@ -67,19 +152,19 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val summary2 = computeSummary(data2RDD) val summary3 = computeSummary(data3RDD) - assert((data, data1, data1RDD.collect()).zipped.forall { + assert((denseData, data1, data1RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data2, data2RDD.collect()).zipped.forall { + assert((denseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data3, data3RDD.collect()).zipped.forall { + assert((denseData, data3, data3RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -107,17 +192,58 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } + test("Standardization with sparse input when means and stds are provided") { + + val dataRDD = sc.parallelize(sparseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data2 = sparseData.map(equivalentModel2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel3.transform) + } + } + + val data2RDD = equivalentModel2.transform(dataRDD) + + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + test("Standardization with sparse input") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), - Vectors.sparse(3, Seq((1, -5.1))), - Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), - Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), - Vectors.sparse(3, Seq((1, 1.9))) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(sparseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -127,25 +253,26 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(model2.transform) + val data2 = sparseData.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model1.transform) + sparseData.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model3.transform) + sparseData.map(model3.transform) } } val data2RDD = model2.transform(dataRDD) - val summary2 = computeSummary(data2RDD) - assert((data, data2, data2RDD.collect()).zipped.forall { + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -153,23 +280,44 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) - assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) } + test("Standardization with constant input when means and stds are provided") { + + val dataRDD = sc.parallelize(constantData, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = constantData.map(equivalentModel1.transform) + val data2 = constantData.map(equivalentModel2.transform) + val data3 = constantData.map(equivalentModel3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + test("Standardization with constant input") { - // When the input data is all constant, the variance is zero. The standardization against - // zero variance is not well-defined, but we decide to just set it into zero here. - val data = Array( - Vectors.dense(2.0), - Vectors.dense(2.0), - Vectors.dense(2.0) - ) - val dataRDD = sc.parallelize(data, 2) + val dataRDD = sc.parallelize(constantData, 2) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler(withMean = true, withStd = false) @@ -179,9 +327,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = constantData.map(model1.transform) + val data2 = constantData.map(model2.transform) + val data3 = constantData.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") @@ -191,4 +339,29 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { "The variance is zero, so the transformed result should be 0.0") } + test("StandardScalerModel argument nulls are properly handled") { + + withClue("model needs at least one of std or mean vectors") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, null) + } + } + withClue("model needs std to set withStd to true") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, Vectors.dense(0.0)) + model.setWithStd(true) + } + } + withClue("model needs mean to set withMean to true") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), null) + model.setWithMean(true) + } + } + withClue("model needs std and mean vectors to be equal size when both are provided") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 52278690dbd8..98a98a7599bc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -21,6 +21,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests @@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { assert(syms(0)._1 == "taiwan") assert(syms(1)._1 == "japan") } + + test("model load / save") { + + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala new file mode 100644 index 000000000000..bd5b9cc3afa1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { + + + test("FP-Growth using String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + val expected = Set( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 54) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 625) + } + + test("FP-Growth using Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model6 = fpg + .setMinSupport(0.9) + .setNumPartitions(1) + .run(rdd) + assert(model6.freqItemsets.count() === 0) + + val model3 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, + "frequent itemsets should use primitive arrays") + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + val expected = Set( + (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), + (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L), + (Set(2, 4), 4L), (Set(1, 2, 3), 4L)) + assert(freqItemsets3.toSet === expected) + + val model2 = fpg + .setMinSupport(0.3) + .setNumPartitions(4) + .run(rdd) + assert(model2.freqItemsets.count() === 15) + + val model1 = fpg + .setMinSupport(0.1) + .setNumPartitions(8) + .run(rdd) + assert(model1.freqItemsets.count() === 65) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala new file mode 100644 index 000000000000..04017f67c311 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import scala.language.existentials + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class FPTreeSuite extends FunSuite with MLlibTestSparkContext { + + test("add transaction") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + assert(tree.root.children.size == 2) + assert(tree.root.children.contains("a")) + assert(tree.root.children("a").item.equals("a")) + assert(tree.root.children("a").count == 2) + assert(tree.root.children.contains("b")) + assert(tree.root.children("b").item.equals("b")) + assert(tree.root.children("b").count == 1) + var child = tree.root.children("a") + assert(child.children.size == 1) + assert(child.children.contains("b")) + assert(child.children("b").item.equals("b")) + assert(child.children("b").count == 2) + child = child.children("b") + assert(child.children.size == 2) + assert(child.children.contains("c")) + assert(child.children.contains("y")) + assert(child.children("c").item.equals("c")) + assert(child.children("y").item.equals("y")) + assert(child.children("c").count == 1) + assert(child.children("y").count == 1) + } + + test("merge tree") { + val tree1 = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("b")) + + val tree2 = new FPTree[String] + .add(Seq("a", "b")) + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "c", "d")) + .add(Seq("a", "x")) + .add(Seq("a", "x", "y")) + .add(Seq("c", "n")) + .add(Seq("c", "m")) + + val tree3 = tree1.merge(tree2) + + assert(tree3.root.children.size == 3) + assert(tree3.root.children("a").count == 7) + assert(tree3.root.children("b").count == 1) + assert(tree3.root.children("c").count == 2) + val child1 = tree3.root.children("a") + assert(child1.children.size == 2) + assert(child1.children("b").count == 5) + assert(child1.children("x").count == 2) + val child2 = child1.children("b") + assert(child2.children.size == 2) + assert(child2.children("y").count == 1) + assert(child2.children("c").count == 3) + val child3 = child2.children("c") + assert(child3.children.size == 1) + assert(child3.children("d").count == 1) + val child4 = child1.children("x") + assert(child4.children.size == 1) + assert(child4.children("y").count == 1) + val child5 = tree3.root.children("c") + assert(child5.children.size == 2) + assert(child5.children("n").count == 1) + assert(child5.children("m").count == 1) + } + + test("extract freq itemsets") { + val tree = new FPTree[String] + .add(Seq("a", "b", "c")) + .add(Seq("a", "b", "y")) + .add(Seq("a", "b")) + .add(Seq("a")) + .add(Seq("b")) + .add(Seq("b", "n")) + + val freqItemsets = tree.extract(3L).map { case (items, count) => + (items.toSet, count) + }.toSet + val expected = Set( + (Set("a"), 4L), + (Set("b"), 5L), + (Set("a", "b"), 3L)) + assert(freqItemsets === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala new file mode 100644 index 000000000000..699f009f0f2e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.scalatest.FunSuite + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { + + import PeriodicGraphCheckpointerSuite._ + + // TODO: Do I need to call count() on the graphs' RDDs? + + test("Persisting") { + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicGraphCheckpointerSuite { + + case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) + + val edges = Seq( + Edge[Double](0, 1, 0), + Edge[Double](1, 2, 0), + Edge[Double](2, 3, 0), + Edge[Double](3, 4, 0)) + + def createGraph(sc: SparkContext): Graph[Double, Double] = { + Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + } + + def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { + graphs.foreach { g => + checkPersistence(g.graph, g.gIndex, iteration) + } + } + + /** + * Check storage level of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(graph.vertices.getStorageLevel == StorageLevel.NONE) + assert(graph.edges.getStorageLevel == StorageLevel.NONE) + } else { + assert(graph.vertices.getStorageLevel != StorageLevel.NONE) + assert(graph.edges.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + + s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") + } + } + + def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { + graphs.reverse.foreach { g => + checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { + // Note: We cannot check graph.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this graph.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + graph.getCheckpointFiles.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), + "Graph checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkCheckpoint( + graph: Graph[_, _], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) + // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(graph.isCheckpointed, "Graph should be checkpointed") + assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(graph) + } + } else { + // Graph should never be checkpointed + assert(!graph.isCheckpointed, "Graph should never have been checkpointed") + assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + + s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea..002cb253862b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -166,19 +166,28 @@ class BLASSuite extends FunSuite { syr(alpha, y, dA) } } + + val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0)) + val dD = new DenseMatrix(4, 4, + Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) + syr(0.1, xSparse, dD) + val expectedSparse = new DenseMatrix(4, 4, + Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4)) + assert(dD ~== expectedSparse absTol 1e-15) } test("gemm") { - val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +197,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) @@ -202,26 +215,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +260,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +280,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d86..203103237397 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a35d0fe389fd..86119ec38101 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -135,8 +137,8 @@ class MatricesSuite extends FunSuite { val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) - val spMat2 = deMat1.toSparse() - val deMat2 = spMat1.toDense() + val spMat2 = deMat1.toSparse + val deMat2 = spMat1.toDense assert(spMat1.toBreeze === spMat2.toBreeze) assert(deMat1.toBreeze === deMat2.toBreeze) @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") + + assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) + assert(sAT.toDense.toBreeze === dATexpected.toBreeze) + } + + test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } @@ -330,4 +424,35 @@ class MatricesSuite extends FunSuite { assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) } + + test("MatrixUDT") { + val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8)) + val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0)) + val dm3 = new DenseMatrix(0, 0, Array()) + val sm1 = dm1.toSparse + val sm2 = dm2.toSparse + val sm3 = dm3.toSparse + val mUDT = new MatrixUDT() + Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { + mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray) + } + assert(mUDT.typeName == "matrix") + assert(mUDT.simpleString == "matrix") + } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 5def899cea11..2839c4c289b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -187,6 +187,8 @@ class VectorsSuite extends FunSuite { for (v <- Seq(dv0, dv1, sv0, sv1)) { assert(v === udt.deserialize(udt.serialize(v))) } + assert(udt.typeName == "vector") + assert(udt.simpleString == "vector") } test("fromBreeze") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala new file mode 100644 index 000000000000..949d1c993957 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import java.{util => ju} + +import breeze.linalg.{DenseMatrix => BDM} +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { + + val m = 5 + val n = 4 + val rowPerPart = 2 + val colPerPart = 2 + val numPartitions = 3 + var gridBasedMat: BlockMatrix = _ + + override def beforeAll() { + super.beforeAll() + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + + gridBasedMat = new BlockMatrix(sc.parallelize(blocks, numPartitions), rowPerPart, colPerPart) + } + + test("size") { + assert(gridBasedMat.numRows() === m) + assert(gridBasedMat.numCols() === n) + } + + test("grid partitioner") { + val random = new ju.Random() + // This should generate a 4x4 grid of 1x2 blocks. + val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + val expected0 = Array( + Array(0, 0, 4, 4, 8, 8, 12), + Array(1, 1, 5, 5, 9, 9, 13), + Array(2, 2, 6, 6, 10, 10, 14), + Array(3, 3, 7, 7, 11, 11, 15)) + for (i <- 0 until 4; j <- 0 until 7) { + assert(part0.getPartition((i, j)) === expected0(i)(j)) + assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((-1, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((4, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, -1)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, 7)) + } + + val part1 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + val expected1 = Array( + Array(0, 2), + Array(1, 3)) + for (i <- 0 until 2; j <- 0 until 2) { + assert(part1.getPartition((i, j)) === expected1(i)(j)) + assert(part1.getPartition((i, j, random.nextInt())) === expected1(i)(j)) + } + + val part2 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + assert(part0 !== part2) + assert(part1 === part2) + + val part3 = new GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + val expected3 = Array( + Array(0, 0, 2), + Array(1, 1, 3)) + for (i <- 0 until 2; j <- 0 until 3) { + assert(part3.getPartition((i, j)) === expected3(i)(j)) + assert(part3.getPartition((i, j, random.nextInt())) === expected3(i)(j)) + } + + val part4 = GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + assert(part3 === part4) + + intercept[IllegalArgumentException] { + new GridPartitioner(2, 2, rowsPerPart = 0, colsPerPart = 1) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, rowsPerPart = 1, colsPerPart = 0) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, suggestedNumPartitions = 0) + } + } + + test("toCoordinateMatrix") { + val coordMat = gridBasedMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toIndexedRowMatrix") { + val rowMat = gridBasedMat.toIndexedRowMatrix() + assert(rowMat.numRows() === m) + assert(rowMat.numCols() === n) + assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toBreeze and toLocalMatrix") { + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val dense = Matrices.fromBreeze(expected).asInstanceOf[DenseMatrix] + assert(gridBasedMat.toLocalMatrix() === dense) + assert(gridBasedMat.toBreeze() === expected) + } + + test("add") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (2.0, 0.0, 0.0, 0.0), + (0.0, 4.0, 2.0, 0.0), + (6.0, 2.0, 2.0, 0.0), + (0.0, 2.0, 4.0, 2.0), + (1.0, 0.0, 2.0, 10.0)) + + val AplusB = gridBasedMat.add(B) + assert(AplusB.numRows() === m) + assert(AplusB.numCols() === B.numCols()) + assert(AplusB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.add(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] { // partitioning doesn't match + gridBasedMat.add(C2) + } + // adding BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) + } + + test("multiply") { + // identity matrix + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val AtimesB = gridBasedMat.multiply(B) + assert(AtimesB.numRows() === m) + assert(AtimesB.numCols() === n) + assert(AtimesB.toBreeze() === expected) + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m + 1, n) // dimensions don't match + intercept[IllegalArgumentException] { + gridBasedMat.multiply(C) + } + val largerBlocks = Seq(((0, 0), DenseMatrix.eye(4))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4) + intercept[SparkException] { + // partitioning doesn't match + gridBasedMat.multiply(C2) + } + val rand = new ju.Random(42) + val largerAblocks = for (i <- 0 until 20) yield ((i % 5, i / 5), DenseMatrix.rand(6, 4, rand)) + val largerBblocks = for (i <- 0 until 16) yield ((i % 4, i / 4), DenseMatrix.rand(4, 4, rand)) + + // Try it with increased number of partitions + val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4) + val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4) + val largeC = largeA.multiply(largeB) + val localC = largeC.toLocalMatrix() + val result = largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix]) + assert(largeC.numRows() === largeA.numRows()) + assert(largeC.numCols() === largeB.numCols()) + assert(localC ~== result absTol 1e-8) + } + + test("validate") { + // No error + gridBasedMat.validate() + // Wrong MatrixBlock dimensions + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val wrongRowPerParts = new BlockMatrix(rdd, rowPerPart + 1, colPerPart) + val wrongColPerParts = new BlockMatrix(rdd, rowPerPart, colPerPart + 1) + intercept[SparkException] { + wrongRowPerParts.validate() + } + intercept[SparkException] { + wrongColPerParts.validate() + } + // Wrong BlockMatrix dimensions + val wrongRowSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 4, 4) + intercept[AssertionError] { + wrongRowSize.validate() + } + val wrongColSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 5, 2) + intercept[AssertionError] { + wrongColSize.validate() + } + // Duplicate indices + val duplicateBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 0), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 1), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val dupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks, numPartitions), 2, 2) + intercept[SparkException] { + dupMatrix.validate() + } + } + + test("transpose") { + val expected = BDM( + (1.0, 0.0, 3.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 0.0, 1.0, 5.0)) + + val AT = gridBasedMat.transpose + assert(AT.numRows() === gridBasedMat.numCols()) + assert(AT.numCols() === gridBasedMat.numRows()) + assert(AT.toBreeze() === expected) + + // make sure it works when matrices are cached as well + gridBasedMat.cache() + val AT2 = gridBasedMat.transpose + AT2.cache() + assert(AT2.toBreeze() === AT.toBreeze()) + val A = AT2.transpose + assert(A.toBreeze() === gridBasedMat.toBreeze()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 80bef814ce50..04b36a9ef999 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -100,4 +100,18 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(0.0, 9.0, 0.0, 0.0)) assert(rows === expected) } + + test("toBlockMatrix") { + val blockMat = mat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === mat.toBreeze()) + + intercept[IllegalArgumentException] { + mat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + mat.toBlockMatrix(2, 0) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index b86c2ca5ff13..2ab53cc13db7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -88,6 +88,21 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(coordMat.toBreeze() === idxRowMat.toBreeze()) } + test("toBlockMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val blockMat = idxRowMat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMat.toBreeze()) + + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(2, 0) + } + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 82c327bd49fc..22855e4e8f24 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -55,7 +55,7 @@ class NNLSSuite extends FunSuite { for (k <- 0 until 100) { val (ata, atb) = genOnesData(n, rand) - val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) assert(x.length === n) val answer = DoubleMatrix.ones(n, 1) SimpleBlas.axpy(-1.0, answer, x) @@ -79,7 +79,7 @@ class NNLSSuite extends FunSuite { val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) val ws = NNLS.createWorkspace(n) - val x = NNLS.solve(ata, atb, ws) + val x = NNLS.solve(ata.data, atb.data, ws) for (i <- 0 until n) { assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) @@ -104,7 +104,7 @@ class NNLSSuite extends FunSuite { val ws = NNLS.createWorkspace(n) - val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) val obj = computeObjectiveValue(ata, atb, x) assert(obj < refObj + 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 6395188a0842..63f2ea916d45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -181,7 +181,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) - val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) + val exponential = + RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) testGeneratedVectorRDD(exponential, rows, cols, parts, exponentialMean, exponentialMean, 0.1) val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed) @@ -197,7 +198,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] { // This allows us to check that each partition has a different seed override def nextValue(): Double = seed.toDouble - override def setSeed(seed: Long) = this.seed = seed + override def setSeed(seed: Long): Unit = this.seed = seed override def copy(): MockDistro = new MockDistro } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala new file mode 100644 index 000000000000..1ac7c12c4e8e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ + +class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { + test("topByKey") { + val topMap = sc.parallelize(Array((1, 1), (1, 2), (3, 2), (3, 7), (3, 5), (5, 1), (5, 3)), 2) + .topByKey(2) + .collectAsMap() + + assert(topMap.size === 3) + assert(topMap(1) === Array(2, 1)) + assert(topMap(3) === Array(7, 5)) + assert(topMap(5) === Array(3, 1)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 681ce9263933..6d6c0aa5be81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } - - test("treeAggregate") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 - for (depth <- 1 until 10) { - val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) - } - } - - test("treeReduce") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - for (depth <- 1 until 10) { - val sum = rdd.treeReduce(_ + _, depth) - assert(sum === -1000) - } - } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index e9fc37e00052..b3798940ddc3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -24,9 +24,7 @@ import scala.util.Random import org.scalatest.FunSuite import org.jblas.DoubleMatrix -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.recommendation.ALS.BlockStats import org.apache.spark.storage.StorageLevel object ALSSuite { @@ -189,22 +187,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false) } - test("analyze one user block and one product block") { - val localRatings = Seq( - Rating(0, 100, 1.0), - Rating(0, 101, 2.0), - Rating(0, 102, 3.0), - Rating(1, 102, 4.0), - Rating(2, 103, 5.0)) - val ratings = sc.makeRDD(localRatings, 2) - val stats = ALS.analyzeBlocks(ratings, 1, 1) - assert(stats.size === 2) - assert(stats(0) === BlockStats("user", 0, 3, 5, 4, 3)) - assert(stats(1) === BlockStats("product", 0, 4, 5, 3, 4)) - } - - // TODO: add tests for analyzing multiple user/product blocks - /** * Test if we can correctly factorize R = U * P where U and P are of known rank. * @@ -221,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ + // scalastyle:off def testALS( users: Int, products: Int, @@ -234,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { numUserBlocks: Int = -1, numProductBlocks: Int = -1, negativeFactors: Boolean = true) { + // scalastyle:on + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904a2..9801e8757674 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala new file mode 100644 index 000000000000..8e12340bbd9d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.scalatest.{Matchers, FunSuite} + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { + + private def round(d: Double) = { + Math.round(d * 100).toDouble / 100 + } + + private def generateIsotonicInput(labels: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, 1d)) + } + + private def generateIsotonicInput( + labels: Seq[Double], + weights: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, weights(i))) + } + + private def runIsotonicRegression( + labels: Seq[Double], + weights: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache() + new IsotonicRegression().setIsotonic(isotonic).run(trainRDD) + } + + private def runIsotonicRegression( + labels: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic) + } + + test("increasing isotonic regression") { + /* + The following result could be re-produced with sklearn. + + > from sklearn.isotonic import IsotonicRegression + > x = range(9) + > y = [1, 2, 3, 1, 6, 17, 16, 17, 18] + > ir = IsotonicRegression(x, y) + > print ir.predict(x) + + array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18. ]) + */ + val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true) + + assert(Array.tabulate(9)(x => model.predict(x)) === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.isotonic) + } + + test("model save/load") { + val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) + val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0) + val model = new IsotonicRegressionModel(boundaries, predictions, true) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = IsotonicRegressionModel.load(sc, path) + assert(model.boundaries === sameModel.boundaries) + assert(model.predictions === sameModel.predictions) + assert(model.isotonic === model.isotonic) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("isotonic regression with size 0") { + val model = runIsotonicRegression(Seq(), true) + + assert(model.predictions === Array()) + } + + test("isotonic regression with size 1") { + val model = runIsotonicRegression(Seq(1), true) + + assert(model.predictions === Array(1.0)) + } + + test("isotonic regression strictly increasing sequence") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 5), true) + + assert(model.predictions === Array(1, 2, 3, 4, 5)) + } + + test("isotonic regression strictly decreasing sequence") { + val model = runIsotonicRegression(Seq(5, 4, 3, 2, 1), true) + + assert(model.boundaries === Array(0, 4)) + assert(model.predictions === Array(3, 3)) + } + + test("isotonic regression with last element violating monotonicity") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 3, 3)) + } + + test("isotonic regression with first element violating monotonicity") { + val model = runIsotonicRegression(Seq(4, 2, 3, 4, 5), true) + + assert(model.boundaries === Array(0, 2, 3, 4)) + assert(model.predictions === Array(3, 3, 4, 5)) + } + + test("isotonic regression with negative labels") { + val model = runIsotonicRegression(Seq(-1, -2, 0, 1, -1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(-1.5, -1.5, 0, 0)) + } + + test("isotonic regression with unordered input") { + val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache() + + val model = new IsotonicRegression().run(trainRDD) + assert(model.predictions === Array(1, 2, 3, 4, 5)) + } + + test("weighted isotonic regression") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), Seq(1, 1, 1, 1, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 2.75, 2.75)) + } + + test("weighted isotonic regression with weights lower than 1") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2)) + } + + test("weighted isotonic regression with negative weights") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(-1, 1, -3, 1, -5), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1.0, 10.0/6, 10.0/6)) + } + + test("weighted isotonic regression with zero weights") { + val model = runIsotonicRegression(Seq[Double](1, 2, 3, 2, 1), Seq[Double](0, 0, 0, 1, 0), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1, 2, 2)) + } + + test("isotonic regression prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + assert(model.predict(-2) === 1) + assert(model.predict(-1) === 1) + assert(model.predict(0.5) === 1.5) + assert(model.predict(0.75) === 1.75) + assert(model.predict(1) === 2) + assert(model.predict(2) === 10d/3) + assert(model.predict(9) === 10d/3) + } + + test("isotonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache() + val model = new IsotonicRegression().run(trainRDD) + + assert(model.predict(0) === 1) + assert(model.predict(1.5) === 2) + assert(model.predict(2.5) === 4.5) + assert(model.predict(4) === 6) + } + + test("antitonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache() + val model = new IsotonicRegression().setIsotonic(false).run(trainRDD) + + assert(model.predict(0) === 6) + assert(model.predict(1.5) === 4.5) + assert(model.predict(2.5) === 2) + assert(model.predict(4) === 1) + } + + test("isotonic regression RDD prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache() + val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2) + assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3)) + } + + test("antitonic regression prediction") { + val model = runIsotonicRegression(Seq(7, 5, 3, 5, 1), false) + + assert(model.predict(-2) === 7) + assert(model.predict(-1) === 7) + assert(model.predict(0.5) === 6) + assert(model.predict(0.75) === 5.5) + assert(model.predict(1) === 5) + assert(model.predict(2) === 4) + assert(model.predict(9) === 1) + } + + test("model construction") { + val model = new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = true) + assert(model.predict(-0.5) === 1.0) + assert(model.predict(0.0) === 1.0) + assert(model.predict(0.5) ~== 1.5 absTol 1e-14) + assert(model.predict(1.0) === 2.0) + assert(model.predict(1.5) === 2.0) + + intercept[IllegalArgumentException] { + // different array sizes. + new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered boundaries + new IsotonicRegressionModel(Array(1.0, 0.0), Array(1.0, 2.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (isotonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(2.0, 1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (antitonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2668dcc14a84..c9f5dc069ef2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LassoSuite { + + /** 3 features */ + val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LassoSuite extends FunSuite with MLlibTestSparkContext { @@ -115,6 +122,23 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("model save/load") { + val model = LassoSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 864622a9296a..3781931c2f81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -24,6 +24,13 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object LinearRegressionSuite { + + /** 3 features */ + val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -124,6 +131,23 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { validatePrediction( sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } + + test("model save/load") { + val model = LinearRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 18d3bf5ea4ec..d6c93cc0e49c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -25,10 +25,17 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils + +private object RidgeRegressionSuite { + + /** 3 features */ + val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size @@ -75,6 +82,23 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("model save/load") { + val model = RidgeRegressionSuite.model + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } + } } class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf..26604dbe6c1e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala new file mode 100644 index 000000000000..16ecae23dd9d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.commons.math3.distribution.NormalDistribution + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { + test("kernel density single sample") { + val rdd = sc.parallelize(Array(5.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal = new NormalDistribution(5.0, 3.0) + val acceptableErr = 1e-6 + assert(densities(0) - normal.density(5.0) < acceptableErr) + assert(densities(0) - normal.density(6.0) < acceptableErr) + } + + test("kernel density multiple samples") { + val rdd = sc.parallelize(Array(5.0, 10.0)) + val evaluationPoints = Array(5.0, 6.0) + val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val normal1 = new NormalDistribution(5.0, 3.0) + val normal2 = new NormalDistribution(10.0, 3.0) + val acceptableErr = 1e-6 + assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) + assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9347eaf9221a..249b8eae19b1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -29,11 +29,17 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { + ///////////////////////////////////////////////////////////////////////////// + // Tests examining individual elements of training + ///////////////////////////////////////////////////////////////////////////// + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -188,7 +194,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 3) - assert(bins(0).length === 6) + assert(bins(0).length === 0) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -226,41 +232,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - // Check bins. - - assert(bins(0)(0).category === Double.MinValue) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(0).category === Double.MinValue) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(0)(1).category === Double.MinValue) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(0.0)) - assert(bins(0)(1).highSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(1).category === Double.MinValue) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 1) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(0)(2).category === Double.MinValue) - assert(bins(0)(2).lowSplit.categories.length === 1) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.length === 2) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).category === Double.MinValue) - assert(bins(1)(2).lowSplit.categories.length === 1) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 2) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - } test("Multiclass classification with ordered categorical features: split and bin calculations") { @@ -287,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(bins(0).length === 0) } + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Second level node building with vs. without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClasses = 2, maxBins = 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + val children1 = new Array[Node](2) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. + val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict.predict === children2(i).predict.predict) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -471,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(rootNode.predict.predict === 1) } - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -561,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) - arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) @@ -577,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 2 continuous features") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -701,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min instances per node requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -728,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("do not choose split that does not satisfy min instance per node requirements") { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -748,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min info gain requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, @@ -772,94 +831,35 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("Avoid aggregation on the last level") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) + test("Node.subtreeIterator") { + val model = DecisionTreeSuite.createModel(Classification) + val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted + assert(nodeIds === DecisionTreeSuite.createdModelNodeIds) } - test("Avoid aggregation if impurity is 0.0") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val model = DecisionTreeSuite.createModel(algo) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = DecisionTreeModel.load(sc, path) + DecisionTreeSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } } } -object DecisionTreeSuite { +object DecisionTreeSuite extends FunSuite { def validateClassifier( model: DecisionTreeModel, @@ -979,4 +979,96 @@ object DecisionTreeSuite { arr } + /** Create a leaf node with the given node ID */ + private def createLeafNode(id: Int): Node = { + Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true) + } + + /** + * Create an internal node with the given node ID and feature type. + * Note: This does NOT set the child nodes. + */ + private def createInternalNode(id: Int, featureType: FeatureType): Node = { + val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false) + featureType match { + case Continuous => + node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous, + categories = List.empty[Double])) + case Categorical => + node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, + categories = List(0.0, 1.0))) + } + // TODO: The information gain stats should be consistent with the same info stored in children. + node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, + leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) + node + } + + /** + * Create a tree model. This is deterministic and contains a variety of node and feature types. + * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) + */ + private[mllib] def createModel(algo: Algo): DecisionTreeModel = { + val topNode = createInternalNode(id = 1, Continuous) + val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) + val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) + topNode.leftNode = Some(node2) + topNode.rightNode = Some(node3) + node3.leftNode = Some(node6) + node3.rightNode = Some(node7) + new DecisionTreeModel(topNode, algo) + } + + /** Sorted Node IDs matching the model returned by [[createModel()]] */ + private val createdModelNodeIds = Array(1, 2, 3, 6, 7) + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + assert(a.algo === b.algo) + checkEqual(a.topNode, b.topNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendents are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.id === b.id) + assert(a.predict === b.predict) + assert(a.impurity === b.impurity) + assert(a.isLeaf === b.isLeaf) + assert(a.split === b.split) + (a.stats, b.stats) match { + // TODO: Check other fields besides the infomation gain. + case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) + case (None, None) => + case _ => throw new AssertionError( + s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + } + (a.leftNode, b.leftNode) match { + case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has leftNode defined. " + + s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + } + (a.rightNode, b.rightNode) match { + case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) + case (None, None) => + case _ => throw new AssertionError("Only one instance has rightNode defined. " + + s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3aa97e544680..55b0bac7d49f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} - +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[GradientBoostedTrees]]. @@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) - - val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) - val boostingStrategy = - new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - - val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - - assert(gbt.trees.size === numIterations) - try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) - } catch { - case e: java.lang.AssertionError => - println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + - s" subsamplingRate=$subsamplingRate") - throw e - } + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) - val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - val dt = DecisionTree.train(remappedInput, treeStrategy) + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) - // Make sure trees are the same. - assert(gbt.trees.head.toString == dt.toString) + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) } } @@ -128,14 +128,90 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } + test("SPARK-5496: BoostingStrategy.defaultParams should recognize Classification") { + for (algo <- Seq("classification", "Classification", "regression", "Regression")) { + BoostingStrategy.defaultParams(algo) + } + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + + Array(Classification, Regression).foreach { algo => + val model = new GradientBoostedTreesModel(algo, trees, treeWeights) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = GradientBoostedTreesModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + assert(model.treeWeights === sameModel.treeWeights) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + (algos zip losses) map { + case (algo, loss) => { + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + } + } -object GradientBoostedTreesSuite { +private object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) - val randomSeeds = Array(681283, 4398) - val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 000000000000..92b498580af0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c12..ee3bc9848686 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.Node +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + /** * Test suite for [[RandomForest]]. @@ -196,6 +198,42 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } -} + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(Classification, Regression).foreach { algo => + val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray + val model = new RandomForestModel(algo, trees) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RandomForestModel.load(sc, path) + assert(model.algo == sameModel.algo) + model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) => + DecisionTreeSuite.checkEqual(treeA, treeB) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index e957fa5d25f4..352193a67860 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -95,16 +95,16 @@ object TestingUtils { /** * Comparison using absolute tolerance. */ - def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, - x, eps, ABS_TOL_MSG) + def absTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) /** * Comparison using relative tolerance. */ - def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, - x, eps, REL_TOL_MSG) + def relTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareVectorRightSide( @@ -166,7 +166,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareMatrixRightSide( @@ -229,7 +229,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index b0ecb33c2848..59e6c778806f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite { assert(!(17.8 ~= 17.59 absTol 0.2)) // Comparisons of numbers very close to zero, and both side of zeros - assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - - assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert( + -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) } test("Comparing vectors using relative error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) @@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite { test("Comparing vectors using absolute error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) diff --git a/network/common/pom.xml b/network/common/pom.xml index 245a96b8c403..22c738bde6d4 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -48,10 +48,15 @@ slf4j-api provided
    + com.google.guava guava - provided + compile @@ -75,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + @@ -87,11 +97,6 @@ maven-jar-plugin 2.2 - - - test-jar - - test-jar-on-test-compile test-compile diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 5bc6e5a2418a..3fe69b1bd885 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,7 +36,6 @@ import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -107,6 +107,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) + .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); @@ -127,7 +128,8 @@ private TransportChannelHandler createChannelHandler(Channel channel) { TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler, + conf.connectionTimeoutMs()); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2044afb0d85d..94fc21af5e60 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ + private final AtomicLong timeOfLastRequestNs; + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingFetches.put(streamChunkId, callback); } @@ -65,6 +70,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingRpcs.put(requestId, callback); } @@ -161,8 +167,12 @@ public void handle(ResponseMessage message) { } /** Returns total number of outstanding requests (fetch requests + rpcs) */ - @VisibleForTesting public int numOutstandingRequests() { return outstandingFetches.size() + outstandingRpcs.size(); } + + /** Returns the time in nanoseconds of when the last request was sent out. */ + public long getTimeOfLastRequestNs() { + return timeOfLastRequestNs.get(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 986957c1509f..f76bb49e874f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -17,7 +17,6 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 873c69425094..9162d0b977f8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -20,7 +20,6 @@ import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; /** Provides a canonical set of Encoders for simple types. */ public class Encoders { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a7..0f999f5dfe8d 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 000000000000..d686a951467c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index ebd764eb5eb5..6b991375fc48 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -17,7 +17,6 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java similarity index 96% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 3777a18e33f7..026cbd260d16 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -19,16 +19,13 @@ import java.util.concurrent.ConcurrentMap; -import com.google.common.base.Charsets; import com.google.common.collect.Maps; -import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java rename to network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index e491367fa452..8e0ee709e38e 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -19,6 +19,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,11 @@ * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. */ public class TransportChannelHandler extends SimpleChannelInboundHandler { private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); @@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index 625c3257d764..b7ce8541e565 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -28,7 +28,6 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; -import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,8 +99,7 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); - channelFuture.syncUninterruptibly(); + bindRightPort(portToBind); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); logger.debug("Shuffle server started on port :" + port); @@ -123,4 +121,37 @@ public void close() { bootstrap = null; } + /** + * Attempt to bind to the specified port up to a fixed number of retries. + * If all attempts fail after the max number of retries, exit. + */ + private void bindRightPort(int portToBind) { + int maxPortRetries = conf.portMaxRetries(); + + for (int i = 0; i <= maxPortRetries; i++) { + int tryPort = -1; + if (0 == portToBind) { + // Do not increment port if tryPort is 0, which is treated as a special port + tryPort = 0; + } else { + // If the new port wraps around, do not try a privilege port + tryPort = ((portToBind + i - 1024) % (65536 - 1024)) + 1024; + } + try { + channelFuture = bootstrap.bind(new InetSocketAddress(tryPort)); + channelFuture.syncUninterruptibly(); + return; + } catch (Exception e) { + logger.warn("Netty service could not bind on port " + tryPort + + ". Attempting the next port."); + if (i >= maxPortRetries) { + logger.error(e.getMessage() + ": Netty server failed after " + + maxPortRetries + " retries."); + + // If it can't find a right port, it should exit directly. + System.exit(-1); + } + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index bf8a1fc42fc6..b6fbace509a0 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,19 +17,17 @@ package org.apache.spark.network.util; -import java.nio.ByteBuffer; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.File; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import com.google.common.base.Preconditions; -import com.google.common.io.Closeables; import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -127,4 +125,66 @@ private static boolean isSymlink(File file) throws IOException { } return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } + + private static ImmutableMap timeSuffixes = + ImmutableMap.builder() + .put("us", TimeUnit.MICROSECONDS) + .put("ms", TimeUnit.MILLISECONDS) + .put("s", TimeUnit.SECONDS) + .put("m", TimeUnit.MINUTES) + .put("min", TimeUnit.MINUTES) + .put("h", TimeUnit.HOURS) + .put("d", TimeUnit.DAYS) + .build(); + + /** + * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for + * internal use. If no suffix is provided a direct conversion is attempted. + */ + private static long parseTimeString(String str, TimeUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + String suffix; + long val; + Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); + if (m.matches()) { + val = Long.parseLong(m.group(1)); + suffix = m.group(2); + } else { + throw new NumberFormatException("Failed to parse time string: " + str); + } + + // Check for invalid suffixes + if (suffix != null && !timeSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); + } catch (NumberFormatException e) { + String timeError = "Time must be specified as seconds (s), " + + "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "E.g. 50s, 100ms, or 250us."; + + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If + * no suffix is provided, the passed number is assumed to be in ms. + */ + public static long timeStringAsMs(String str) { + return parseTimeString(str, TimeUnit.MILLISECONDS); + } + + /** + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If + * no suffix is provided, the passed number is assumed to be in seconds. + */ + public static long timeStringAsSec(String str) { + return parseTimeString(str, TimeUnit.SECONDS); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java new file mode 100644 index 000000000000..668d2356b955 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.collect.Maps; + +import java.util.Map; +import java.util.NoSuchElementException; + +/** ConfigProvider based on a Map (copied in the constructor). */ +public class MapConfigProvider extends ConfigProvider { + private final Map config; + + public MapConfigProvider(Map config) { + this.config = Maps.newHashMap(config); + } + + @Override + public String get(String name) { + String value = config.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 2a4b88b64cdc..26c6399ce7db 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -25,7 +25,6 @@ import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; -import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.epoll.EpollSocketChannel; @@ -99,7 +98,7 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "<remote address>" if none exists. */ + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6c9178688693..0aef7f198731 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -37,8 +37,11 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - int defaultTimeout = conf.getInt("spark.network.timeout", 120); - return conf.getInt("spark.shuffle.io.connectionTimeout", defaultTimeout) * 1000; + long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( + conf.get("spark.network.timeout", "120s")); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ @@ -68,7 +71,9 @@ public int numConnectionsPerPeer() { public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { return conf.getInt("spark.shuffle.sasl.timeout", 30) * 1000; } + public int saslRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. @@ -80,7 +85,9 @@ public int numConnectionsPerPeer() { * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. */ - public int ioRetryWaitTimeMs() { return conf.getInt("spark.shuffle.io.retryWait", 5) * 1000; } + public int ioRetryWaitTimeMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + } /** * Minimum size of a block that we should start using memory map rather than reading in through @@ -98,4 +105,11 @@ public int memoryMapBytes() { public boolean lazyFileDescriptor() { return conf.getBoolean("spark.shuffle.io.lazyFD", true); } + + /** + * Maximum number of retries when binding to a port before giving up. + */ + public int portMaxRetries() { + return conf.getInt("spark.port.maxRetries", 16); + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java new file mode 100644 index 000000000000..b525ed69fc9f --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +public class ByteArrayWritableChannel implements WritableByteChannel { + + private final byte[] data; + private int offset; + + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } + + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { + + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c719..860dd6d9b391 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 000000000000..84ebb337e6d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "2s"); + conf = new TransportConf(new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.success.length == response.length); + } + + // Second times out after 2 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client.sendRpc(new byte[0], callback1); + callback1.wait(4 * 1000); + assert (callback1.failure != null); + assert (callback1.failure instanceof IOException); + } + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client0.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.failure instanceof IOException); + assert (!client0.isActive()); + } + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client1.sendRpc(new byte[0], callback1); + callback1.wait(FOREVER); + assert (callback1.success.length == response.length); + assert (callback1.failure == null); + } + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. + TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + synchronized (callback0) { + // not complete yet, but should complete soon + assert (callback0.success == null && callback0.failure == null); + callback0.wait(2 * 1000); + assert (callback0.failure instanceof IOException); + } + + synchronized (callback1) { + // failed at same time as previous + assert (callback0.failure instanceof IOException); + } + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + byte[] success; + Throwable failure; + + @Override + public void onSuccess(byte[] response) { + synchronized(this) { + success = response; + this.notifyAll(); + } + } + + @Override + public void onFailure(Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + synchronized(this) { + try { + success = buffer.nioByteBuffer().array(); + this.notifyAll(); + } catch (IOException e) { + // weird + } + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 416dc1b969fa..35de5e57ccb9 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.NoSuchElementException; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,9 +37,9 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -70,16 +71,10 @@ public void tearDown() { */ private void testClientReuse(final int maxConnections, boolean concurrent) throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { - @Override - public String get(String name) { - if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { - return Integer.toString(maxConnections); - } else { - throw new NoSuchElementException(); - } - } - }); + + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); + TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 000000000000..ff985096d72d --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java similarity index 95% rename from network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java rename to network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 67a07f38eb5a..23b4e06f064e 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -17,12 +17,12 @@ package org.apache.spark.network.sasl; -import java.util.Map; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; -import com.google.common.collect.ImmutableMap; import org.junit.Test; -import static org.junit.Assert.*; /** * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 000000000000..e8da774f7ca9 --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 5bfa1ac9c373..7dc7c65825e3 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -52,7 +52,6 @@ com.google.guava guava - provided diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 8ed2e0b39ad2..e653f5cb147e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -29,7 +29,6 @@ import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; -import org.apache.spark.network.util.JavaUtils; /** * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 62fce9b0d16c..60485bace643 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -23,7 +23,6 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ public class OpenBlocks extends BlockTransferMessage { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index 7eb438504407..38acae3b31d6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -21,7 +21,6 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** * Initial registration message between an executor and its local shuffle server. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index bc9daa6158ba..9a9220211a50 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -20,8 +20,6 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index 0b23e112bd51..2ff9aaa650f9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -23,7 +23,6 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 842741e3d354..b35a6d685dd0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -28,11 +28,16 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index acec8f18f2b5..1e2e9c80af6c 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -33,6 +33,8 @@ http://spark.apache.org/ network-yarn + + provided @@ -47,7 +49,6 @@ org.apache.hadoop hadoop-client - provided diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index a34aabe9e78a..63b21222e7b7 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -76,6 +76,9 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + // Handles registering executors and opening shuffle blocks + private ExternalShuffleBlockHandler blockHandler; + public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); @@ -99,7 +102,8 @@ protected void serviceInit(Configuration conf) { // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + blockHandler = new ExternalShuffleBlockHandler(transportConf); + RpcHandler rpcHandler = blockHandler; if (authEnabled) { secretManager = new ShuffleSecretManager(); rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); @@ -136,6 +140,7 @@ public void stopApplication(ApplicationTerminationContext context) { if (isAuthenticationEnabled()) { secretManager.unregisterApp(appId); } + blockHandler.applicationRemoved(appId, false /* clean up local dirs */); } catch (Exception e) { logger.error("Exception when stopping application {}", appId, e); } diff --git a/pom.xml b/pom.xml index b993391b1504..95ccbf5abcd4 100644 --- a/pom.xml +++ b/pom.xml @@ -25,8 +25,8 @@ 14 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -105,6 +105,7 @@ external/zeromq examples repl + launcher @@ -117,9 +118,9 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 - 1.0.4 + 2.2.0 2.4.1 ${hadoop.version} 0.98.7-hadoop1 @@ -135,22 +136,30 @@ 1.6.0rc3 1.2.3 8.1.14.v20131031 + 3.0.0.v201112011016 0.5.0 - 3.0.0 - 1.7.6 + 2.4.0 + 2.0.8 + 3.1.0 + 1.7.7 0.7.1 1.8.3 1.1.0 - 4.2.6 - 3.1.1 + 4.3.2 + 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 2.10 ${scala.version} org.scala-lang + 3.6.3 1.8.8 - 1.1.1.6 + 2.4.4 + 1.1.1.7 + 1.1.2 + + ${java.home} Maven Repository - https://repo1.maven.org/maven2 + http://repo1.maven.org/maven2 true @@ -193,7 +202,7 @@ apache-repo Apache Repository - https://repository.apache.org/content/repositories/releases + http://repository.apache.org/content/repositories/releases true @@ -204,7 +213,7 @@ jboss-repo JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases + http://repository.jboss.org/nexus/content/repositories/releases true @@ -215,7 +224,7 @@ mqtt-repo MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases + http://repo.eclipse.org/content/repositories/paho-releases true @@ -226,7 +235,7 @@ cloudera-repo Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos + http://repository.cloudera.com/artifactory/cloudera-repos true @@ -248,7 +257,7 @@ spring-releases Spring Release Repository - https://repo.spring.io/libs-release + http://repo.spring.io/libs-release true @@ -260,7 +269,7 @@ central - https://repo1.maven.org/maven2 + http://repo1.maven.org/maven2 true @@ -337,25 +346,51 @@ + + + + org.eclipse.jetty + jetty-http + ${jetty.version} + provided + + + org.eclipse.jetty + jetty-continuation + ${jetty.version} + provided + + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + provided + org.eclipse.jetty jetty-util ${jetty.version} + provided org.eclipse.jetty jetty-security ${jetty.version} + provided org.eclipse.jetty jetty-plus ${jetty.version} + provided org.eclipse.jetty jetty-server ${jetty.version} + provided com.google.guava @@ -363,6 +398,8 @@ 14.0.1 provided + + org.apache.commons commons-lang3 @@ -371,7 +408,7 @@ commons-codec commons-codec - 1.5 + 1.10 org.apache.commons @@ -383,12 +420,29 @@ jsr305 1.3.9 + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + org.apache.httpcomponents + httpcore + ${commons.httpclient.version} + org.seleniumhq.selenium selenium-java 2.42.2 test + + + xml-apis + xml-apis + 1.4.01 + test + org.slf4j slf4j-api @@ -521,30 +575,48 @@ ${derby.version} - com.codahale.metrics + io.dropwizard.metrics metrics-core ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-jvm ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-json ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-graphite ${codahale.metrics.version} + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + + + + com.fasterxml.jackson.module + jackson-module-scala_2.10 + ${fasterxml.jackson.version} + + + com.google.guava + guava + + + org.scala-lang scala-compiler @@ -576,19 +648,6 @@ 2.2.1 test - - org.easymock - easymockclassextension - 3.1 - test - - - - asm - asm - 3.3.1 - test - org.mockito mockito-all @@ -896,6 +955,16 @@ ${codehaus.jackson.version} ${hadoop.deps.scope} + + org.codehaus.jackson + jackson-xc + ${codehaus.jackson.version} + + + org.codehaus.jackson + jackson-jaxrs + ${codehaus.jackson.version} + ${hive.group} hive-beeline @@ -922,6 +991,10 @@ com.esotericsoftware.kryo kryo + + org.apache.avro + avro-mapred + @@ -1039,6 +1112,12 @@ scala-maven-plugin 3.2.0 + + eclipse-add-source + + add-source + + scala-compile-first process-resources @@ -1121,13 +1200,20 @@ ${project.build.directory}/surefire-reports -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + + ${test_classpath} + ${test.java.home} + true - ${session.executionRootDirectory} + ${spark.test.home} 1 false false - ${test_classpath} true false @@ -1151,6 +1237,7 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + ${test.java.home} true @@ -1192,6 +1279,7 @@ create-source-jar jar-no-fork + test-jar-no-fork @@ -1264,7 +1352,10 @@ - + org.apache.maven.plugins maven-shade-plugin @@ -1273,9 +1364,44 @@ false + org.spark-project.spark:unused + + org.eclipse.jetty:jetty-io + org.eclipse.jetty:jetty-http + org.eclipse.jetty:jetty-continuation + org.eclipse.jetty:jetty-servlet + org.eclipse.jetty:jetty-plus + org.eclipse.jetty:jetty-security + org.eclipse.jetty:jetty-util + org.eclipse.jetty:jetty-server + com.google.guava:guava + + + org.eclipse.jetty + org.spark-project.jetty + + org.eclipse.jetty.** + + + + com.google.common + org.spark-project.guava + + + com/google/common/base/Absent* + com/google/common/base/Function + com/google/common/base/Optional* + com/google/common/base/Present* + com/google/common/base/Supplier + + + @@ -1331,7 +1457,7 @@ org.scalastyle scalastyle-maven-plugin - 0.4.0 + 0.7.0 false true @@ -1340,12 +1466,12 @@ ${basedir}/src/main/scala ${basedir}/src/test/scala scalastyle-config.xml - scalastyle-output.xml - UTF-8 + ${basedir}/target/scalastyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} - package check @@ -1361,6 +1487,25 @@ org.scalatest scalatest-maven-plugin + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + prepare-package + + test-jar + + + + log4j.properties + + + + + @@ -1468,6 +1613,7 @@ 2.5.0 0.98.7-hadoop2 hadoop2 + 1.9.13 @@ -1476,10 +1622,11 @@ 2.3.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1488,10 +1635,11 @@ 2.4.0 2.5.0 - 0.9.0 + 0.9.3 0.98.7-hadoop2 3.1.1 hadoop2 + 1.9.13 @@ -1508,7 +1656,7 @@ 1.0.3-mapr-3.0.3 2.4.1-mapr-1408 - 0.94.17-mapr-1405 + 0.98.4-mapr-1408 3.4.5-mapr-1406 @@ -1518,7 +1666,7 @@ 2.4.1-mapr-1408 2.4.1-mapr-1408 - 0.94.17-mapr-1405-4.0.0-FCS + 0.98.4-mapr-1408 3.4.5-mapr-1406 @@ -1577,9 +1725,20 @@ external/kafka + external/kafka-assembly + + test-java-home + + env.JAVA_HOME + + + ${env.JAVA_HOME} + + + scala-2.11 @@ -1613,5 +1772,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f0cbf4e57b8c..dde92949fa17 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.2.0" + val previousSparkVersion = "1.3.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bc5d81f12d74..7ef363a2f07a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -16,6 +16,7 @@ */ import com.typesafe.tools.mima.core._ +import com.typesafe.tools.mima.core.ProblemFilters._ /** * Additional excludes for checking of Spark's binary compatibility. @@ -33,9 +34,50 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + ) + case v if v.startsWith("1.3") => Seq( MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in the 1.2 build. MimaBuild.excludeSparkPackage("unused"), @@ -52,6 +94,29 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") + ) ++ Seq( + // SPARK-5540 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( @@ -81,11 +146,30 @@ object MimaExcludes { ) ++ Seq( // SPARK-5166 Spark SQL API stabilization ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") ) ++ Seq( // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") ) ++ Seq( // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( @@ -94,6 +178,53 @@ object MimaExcludes { // SPARK-5315 Spark Streaming Java API returns Scala DStream ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") + ) ++ Seq( + // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.getCheckpointFiles"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-5814 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") + ) ++ Seq( + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ded4b5443a90..09b4976d10c2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -import java.io.File +import java.io._ import scala.util.Properties import scala.collection.JavaConversions._ @@ -34,18 +34,19 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn) = - Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _)) + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. @@ -118,7 +119,9 @@ object SparkBuild extends PomBuild { lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( - javaHome := Properties.envOrNone("JAVA_HOME").map(file), + javaHome := sys.env.get("JAVA_HOME") + .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) + .map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", @@ -154,14 +157,18 @@ object SparkBuild extends PomBuild { (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) // TODO: Add Sql to mima checks + // TODO: remove launcher from this list after 1.3. allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Package pyspark artifacts in the main assembly. */ + enable(PySparkAssembly.settings)(assembly) + /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) @@ -176,6 +183,29 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + + /** + * Adds the ability to run the spark shell directly from SBT without building an assembly + * jar. + * + * Usage: `build/sbt sparkShell` + */ + val sparkShell = taskKey[Unit]("start a spark-shell.") + + enable(Seq( + connectInput in run := true, + fork := true, + outputStrategy in run := Some (StdoutOutput), + + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"), + + sparkShell := { + (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value + } + ))(assembly) + + enable(Seq(sparkShell := sparkShell in "assembly"))(spark) + // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { super.projectDefinitions(baseDirectory).map { x => @@ -245,9 +275,9 @@ object SQL { |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution + |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.types._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()" ) } @@ -275,10 +305,10 @@ object Hive { |import org.apache.spark.sql.catalyst.rules._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution + |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.types._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new @@ -289,6 +319,7 @@ object Hive { } object Assembly { + import sbtassembly.AssemblyUtils._ import sbtassembly.Plugin._ import AssemblyKeys._ @@ -300,7 +331,14 @@ object Assembly { sys.props.get("hadoop.version") .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, - jarName in assembly := s"${moduleName.value}-${version.value}-hadoop${hadoopVersion.value}.jar", + jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + if (mName.contains("streaming-kafka-assembly")) { + // This must match the same name used in maven (see external/kafka-assembly/pom.xml) + s"${mName}-${v}.jar" + } else { + s"${mName}-${v}-hadoop${hv}.jar" + } + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -313,6 +351,60 @@ object Assembly { ) } +object PySparkAssembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ + + lazy val settings = Seq( + unmanagedJars in Compile += { BuildCommons.sparkHome / "python/lib/py4j-0.8.2.1-src.zip" }, + // Use a resource generator to copy all .py files from python/pyspark into a managed directory + // to be included in the assembly. We can't just add "python/" to the assembly's resource dir + // list since that will copy unneeded / unwanted files. + resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + val dst = new File(outDir, "pyspark") + if (!dst.isDirectory()) { + require(dst.mkdirs()) + } + + val src = new File(BuildCommons.sparkHome, "python/pyspark") + copy(src, dst) + } + ) + + private def copy(src: File, dst: File): Seq[File] = { + src.listFiles().flatMap { f => + val child = new File(dst, f.getName()) + if (f.isDirectory()) { + child.mkdir() + copy(f, child) + } else if (f.getName().endsWith(".py")) { + var in: Option[FileInputStream] = None + var out: Option[FileOutputStream] = None + try { + in = Some(new FileInputStream(f)) + out = Some(new FileOutputStream(child)) + + val bytes = new Array[Byte](1024) + var read = 0 + while (read >= 0) { + read = in.get.read(bytes) + if (read > 0) { + out.get.write(bytes, 0, read) + } + } + + Some(child) + } finally { + in.foreach(_.close()) + out.foreach(_.close()) + } + } else { + None + } + } + } +} + object Unidoc { import BuildCommons._ @@ -324,25 +416,38 @@ object Unidoc { names.map(s => "org.apache.spark." + s).mkString(":") } + private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { + packages + .map(_.filterNot(_.getName.contains("$"))) + .map(_.filterNot(_.getCanonicalPath.contains("akka"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("python"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) + } + lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + + // Skip actual catalyst, but include the subproject. + // Catalyst is not public API and contains quasiquotes which break scaladoc. + unidocAllSources in (ScalaUnidoc, unidoc) := { + ignoreUndocumentedPackages((unidocAllSources in (ScalaUnidoc, unidoc)).value) + }, // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { - (unidocAllSources in (JavaUnidoc, unidoc)).value - .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("shuffle"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) - .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) + ignoreUndocumentedPackages((unidocAllSources in (JavaUnidoc, unidoc)).value) }, // Javadoc options: create a window title, and group key packages on index page @@ -361,11 +466,15 @@ object Unidoc { "mllib.tree.impurity", "mllib.tree.model", "mllib.util", "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning" + "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", + "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" - ) + ), + + // Group similar methods together based on the @group annotation. + scalacOptions in (ScalaUnidoc, unidoc) ++= Seq("-groups") ) } @@ -375,6 +484,12 @@ object TestSettings { lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, + // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes + // launched by the tests have access to the correct test-time classpath. + envVars in Test ++= Map( + "SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -387,10 +502,6 @@ object TestSettings { javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, - // This places test scope jars on the classpath of executors during tests. - javaOptions in Test += - "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. - map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), diff --git a/project/build.properties b/project/build.properties index 32a3aeefaf9f..064ec843da9e 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.6 +sbt.version=0.13.7 diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905..7096b0d3ee7d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da41..471d00bd8223 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/python/docs/conf.py b/python/docs/conf.py index e58d97ae6a74..163987dd8e5f 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -48,16 +48,16 @@ # General information about the project. project = u'PySpark' -copyright = u'2014, Author' +copyright = u'' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.2-SNAPSHOT' +version = 'master' # The full version, including alpha/beta/rc tags. -release = '1.2-SNAPSHOT' +release = os.environ.get('RELEASE_VERSION', version) # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -97,6 +97,10 @@ # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False +# -- Options for autodoc -------------------------------------------------- + +# Look at the first line of the docstring for function and method signatures. +autodoc_docstring_signature = True # -- Options for HTML output ---------------------------------------------- diff --git a/python/docs/index.rst b/python/docs/index.rst index 703bef644de2..f7eede9c3c82 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -14,6 +14,7 @@ Contents: pyspark pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib @@ -28,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.sql.SQLContext` + + Main entry point for DataFrame and SQL functionality. + + :class:`pyspark.sql.DataFrame` + + A distributed collection of data grouped into named columns. + Indices and tables ================== diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst new file mode 100644 index 000000000000..4da6d4a74a29 --- /dev/null +++ b/python/docs/pyspark.ml.rst @@ -0,0 +1,26 @@ +pyspark.ml package +===================== + +Module Context +-------------- + +.. automodule:: pyspark.ml + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.feature module +------------------------- + +.. automodule:: pyspark.ml.feature + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.classification module +-------------------------------- + +.. automodule:: pyspark.ml.classification + :members: + :undoc-members: + :inherited-members: diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 4548b8739ed9..26ece4c2c389 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -1,16 +1,13 @@ pyspark.mllib package ===================== -Submodules ----------- - pyspark.mllib.classification module ----------------------------------- .. automodule:: pyspark.mllib.classification :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.clustering module ------------------------------- @@ -18,7 +15,13 @@ pyspark.mllib.clustering module .. automodule:: pyspark.mllib.clustering :members: :undoc-members: - :show-inheritance: + +pyspark.mllib.evaluation module +------------------------------- + +.. automodule:: pyspark.mllib.evaluation + :members: + :undoc-members: pyspark.mllib.feature module ------------------------------- @@ -28,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- @@ -42,7 +52,6 @@ pyspark.mllib.random module .. automodule:: pyspark.mllib.random :members: :undoc-members: - :show-inheritance: pyspark.mllib.recommendation module ----------------------------------- @@ -50,7 +59,6 @@ pyspark.mllib.recommendation module .. automodule:: pyspark.mllib.recommendation :members: :undoc-members: - :show-inheritance: pyspark.mllib.regression module ------------------------------- @@ -58,7 +66,7 @@ pyspark.mllib.regression module .. automodule:: pyspark.mllib.regression :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.stat module ------------------------- @@ -66,7 +74,6 @@ pyspark.mllib.stat module .. automodule:: pyspark.mllib.stat :members: :undoc-members: - :show-inheritance: pyspark.mllib.tree module ------------------------- @@ -74,7 +81,7 @@ pyspark.mllib.tree module .. automodule:: pyspark.mllib.tree :members: :undoc-members: - :show-inheritance: + :inherited-members: pyspark.mllib.util module ------------------------- @@ -82,4 +89,3 @@ pyspark.mllib.util module .. automodule:: pyspark.mllib.util :members: :undoc-members: - :show-inheritance: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index e81be3b6cb79..0df12c49ad03 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -9,6 +9,7 @@ Subpackages pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib Contents diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 65b3650ae10a..6259379ed05b 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -1,10 +1,23 @@ pyspark.sql module ================== -Module contents ---------------- +Module Context +-------------- .. automodule:: pyspark.sql :members: :undoc-members: - :show-inheritance: + + +pyspark.sql.types module +------------------------ +.. automodule:: pyspark.sql.types + :members: + :undoc-members: + + +pyspark.sql.functions module +---------------------------- +.. automodule:: pyspark.sql.functions + :members: + :undoc-members: diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index f08185627d0b..50822c93faba 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -8,3 +8,10 @@ Module contents :members: :undoc-members: :show-inheritance: + +pyspark.streaming.kafka module +------------------------------ +.. automodule:: pyspark.streaming.kafka + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9556e4718e58..5f70ac6ed8fe 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -22,17 +22,17 @@ - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - :class:`RDD`: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - :class:`Broadcast`: A broadcast variable that gets reused across tasks. - - L{Accumulator} + - :class:`Accumulator`: An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - :class:`SparkConf`: For configuring Spark. - - L{SparkFiles} + - :class:`SparkFiles`: Access files shipped with jobs. - - L{StorageLevel} + - :class:`StorageLevel`: Finer-grained cache persistence levels. """ @@ -45,6 +45,8 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.status import * +from pyspark.profiler import Profiler, BasicProfiler # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row @@ -52,4 +54,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", + "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b..0d21a132048a 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -54,7 +54,7 @@ ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): -... for i in xrange(len(val1)): +... for i in range(len(val1)): ... val1[i] += val2[i] ... return val1 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) @@ -83,12 +83,16 @@ >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... -Exception:... +TypeError:... """ +import sys import select import struct -import SocketServer +if sys.version < '3': + import SocketServer +else: + import socketserver as SocketServer import threading from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer @@ -215,21 +219,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ @@ -262,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer): def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + self.server_close() def _start_update_server(): diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 6b8a8b256a89..3de4615428bb 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -16,10 +16,15 @@ # import os -import cPickle +import sys import gc from tempfile import NamedTemporaryFile +if sys.version < '3': + import cPickle as pickle +else: + import pickle + unicode = str __all__ = ['Broadcast'] @@ -70,33 +75,19 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None): self._path = path def dump(self, value, f): - if isinstance(value, basestring): - if isinstance(value, unicode): - f.write('U') - value = value.encode('utf8') - else: - f.write('S') - f.write(value) - else: - f.write('P') - cPickle.dump(value, f, 2) + pickle.dump(value, f, 2) f.close() return f.name def load(self, path): with open(path, 'rb', 1 << 20) as f: - flag = f.read(1) - data = f.read() - if flag == 'P': - # cPickle.loads() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return cPickle.loads(data) - finally: - gc.enable() - else: - return data.decode('utf8') if flag == 'U' else data + # pickle.load() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return pickle.load(f) + finally: + gc.enable() @property def value(self): diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index bb0783555aa7..9ef93071d2e7 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -40,164 +40,126 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ - +from __future__ import print_function import operator import os +import io import pickle import struct import sys import types from functools import partial import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new import dis import traceback -import platform - -PyImp = platform.python_implementation() - -import logging -cloudLog = logging.getLogger("Cloud.Transport") +if sys.version < '3': + from pickle import Pickler + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + PY3 = False +else: + types.ClassType = type + from pickle import _Pickler as Pickler + from io import BytesIO as StringIO + PY3 = True #relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +STORE_GLOBAL = dis.opname.index('STORE_GLOBAL') +DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL') +LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL') GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +HAVE_ARGUMENT = dis.HAVE_ARGUMENT +EXTENDED_ARG = dis.EXTENDED_ARG -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) - -if PyImp == "PyPy": - # register builtin type in `new` - new.method = types.MethodType - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -# These helper functions were copied from PiCloud's util module. def islambda(func): - return getattr(func,'func_name') == '' + return getattr(func,'__name__') == '' -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ - - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) - -#debug variables intended for developer use: -printSerialization = False -printMemoization = False +_BUILTIN_TYPE_NAMES = {} +for k, v in types.__dict__.items(): + if type(v) is type: + _BUILTIN_TYPE_NAMES[v] = k -useForcedImports = True #Should I use forced imports for tracking? +def _builtin_type(name): + return getattr(types, name) -class CloudPickler(pickle.Pickler): +class CloudPickler(Pickler): - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment + dispatch = Pickler.dispatch.copy() - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + def __init__(self, file, protocol=None): + Pickler.__init__(self, file, protocol) + # set of modules to unpickle + self.modules = set() + # map ids to dictionary. used to ensure that functions can share global env + self.globals_ref = {} def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) self.inject_addons() try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: + return Pickler.dump(self, obj) + except RuntimeError as e: if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" + msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) + + def save_memoryview(self, obj): + """Fallback to save_string""" + Pickler.save_string(self, str(obj)) def save_buffer(self, obj): """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer + Pickler.save_string(self,str(obj)) + if PY3: + dispatch[memoryview] = save_memoryview + else: + dispatch[buffer] = save_buffer - #block broken objects - def save_unsupported(self, obj, pack=None): + def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) dispatch[types.GeneratorType] = save_unsupported - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! + # itertools objects do not pickle! for v in itertools.__dict__.values(): if type(v) is type: dispatch[v] = save_unsupported - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): + def save_module(self, obj): """ Save a module as an import """ - #print 'try save import', obj.__name__ self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type + self.save_reduce(subimport, (obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module - def save_codeobject(self, obj, pack=struct.pack): + def save_codeobject(self, obj): """ Save a code object """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) + if PY3: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, + obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type + dispatch[types.CodeType] = save_codeobject - def save_function(self, obj, name=None, pack=struct.pack): + def save_function(self, obj, name=None): """ Registered with the dispatch to handle all function types. Determines what kind of function obj is (e.g. lambda, defined at @@ -205,12 +167,14 @@ def save_function(self, obj, name=None, pack=struct.pack): """ write = self.write - name = obj.__name__ + if name is None: + name = obj.__name__ modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) + # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + except KeyError: + # eval'd items such as namedtuple give invalid items for their function __module__ modname = '__main__' if modname == '__main__': @@ -221,37 +185,18 @@ def save_function(self, obj, name=None, pack=struct.pack): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: - #Force server to import modules that have been imported in main - modList = None - if themodule is None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) + if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: + #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + self.save_function_tuple(obj) return - else: # func is nested + else: + # func is nested klass = getattr(themodule, name, None) if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) + self.save_function_tuple(obj) return if obj.__dict__: @@ -266,7 +211,7 @@ def save_function(self, obj, name=None, pack=struct.pack): self.memoize(obj) dispatch[types.FunctionType] = save_function - def save_function_tuple(self, func, forced_imports): + def save_function_tuple(self, func): """ Pickles an actual func object. A func comprises: code, globals, defaults, closure, and dict. We @@ -281,19 +226,6 @@ def save_function_tuple(self, func, forced_imports): save = self.save write = self.write - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater @@ -318,6 +250,8 @@ def extract_code_globals(co): Find all globals names read or written to by codeblock co """ code = co.co_code + if not PY3: + code = [ord(c) for c in code] names = co.co_names out_names = set() @@ -327,18 +261,18 @@ def extract_code_globals(co): while i < n: op = code[i] - i = i+1 + i += 1 if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + oparg = code[i] + code[i+1] * 256 + extended_arg extended_arg = 0 - i = i+2 + i += 2 if op == EXTENDED_ARG: - extended_arg = oparg*65536L + extended_arg = oparg*65536 if op in GLOBAL_OPS: out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - if co.co_consts: # see if nested function have any global refs + # see if nested function have any global refs + if co.co_consts: for const in co.co_consts: if type(const) is types.CodeType: out_names |= CloudPickler.extract_code_globals(const) @@ -350,46 +284,28 @@ def extract_func_data(self, func): Turn the function into a tuple of data necessary to recreate it: code, globals, defaults, closure, dict """ - code = func.func_code + code = func.__code__ # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) + func_global_refs = self.extract_code_globals(code) # process all variables referenced by global environment f_globals = {} for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] + if var in func.__globals__: + f_globals[var] = func.__globals__[var] # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - + defaults = func.__defaults__ # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] + closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) + dct = func.__dict__ - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals + base_globals = self.globals_ref.get(id(func.__globals__), {}) + self.globals_ref[id(func.__globals__)] = base_globals return (code, f_globals, defaults, closure, dct, base_globals) @@ -400,8 +316,9 @@ def save_builtin_function(self, obj): dispatch[types.BuiltinFunctionType] = save_builtin_function def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) if name is None: name = obj.__name__ @@ -410,98 +327,57 @@ def save_global(self, obj, name=None, pack=struct.pack): if modname is None: modname = pickle.whichmodule(obj, name) - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - if modname == '__main__': themodule = None - - if themodule: + else: + __import__(modname) + themodule = sys.modules[modname] self.modules.add(themodule) - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - # print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise + if hasattr(themodule, name) and getattr(themodule, name) is obj: + return Pickler.save_global(self, obj, name) - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + d = dict(obj.__dict__) # copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): + # don't extract dict that are properties + d.pop('__dict__', None) + d.pop('__weakref__', None) # hack as __new__ is stored differently in the __dict__ new_override = d.get('__new__', None) if new_override: d['__new__'] = obj.__new__ - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return - - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) # restores function attributes def _restore_attr(obj, attr): @@ -851,13 +636,16 @@ def _restore_attr(obj, attr): setattr(obj, key, val) return obj + def _get_module_builtins(): return pickle.__builtins__ + def print_exec(stream): ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + def _modules_to_main(modList): """Force every module in modList to be placed into main""" if not modList: @@ -868,22 +656,16 @@ def _modules_to_main(modList): if type(modname) is str: try: mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) + except Exception as e: + sys.stderr.write('warning: could not import %s\n. ' + 'Your function may unexpectedly error due to this import failing;' + 'A version mismatch is likely. Specific error was:\n' % modname) print_exec(sys.stderr) else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) + setattr(main, mod.__name__, mod) -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) +#object generators: def _genpartial(func, args, kwds): if not args: args = () @@ -891,22 +673,26 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) + def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict + func.__globals__.update(globals) + func.__defaults__ = defaults + func.__dict__ = dict return func + def _make_cell(value): - return (lambda: value).func_closure[0] + return (lambda: value).__closure__[0] + def _reconstruct_closure(values): return tuple([_make_cell(v) for v in values]) + def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other @@ -928,40 +714,3 @@ def _make_skel_func(code, closures, base_globals = None): def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index dc7cd0bce56f..924da3eecf21 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -44,7 +44,7 @@ >>> conf.get("spark.executorEnv.VAR1") u'value1' ->>> print conf.toDebugString() +>>> print(conf.toDebugString()) spark.executorEnv.VAR1=value1 spark.executorEnv.VAR3=value3 spark.executorEnv.VAR4=value4 @@ -56,6 +56,13 @@ __all__ = ['SparkConf'] +import sys +import re + +if sys.version > '3': + unicode = str + __doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__) + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4..b006120eb266 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -15,12 +15,13 @@ # limitations under the License. # +from __future__ import print_function + import os import shutil import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -31,10 +32,13 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.status import StatusTracker +from pyspark.profiler import ProfilerCollector, BasicProfiler -from py4j.java_collections import ListConverter +if sys.version > '3': + xrange = range __all__ = ['SparkContext'] @@ -58,15 +62,16 @@ class SparkContext(object): _gateway = None _jvm = None - _writeToFile = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH + PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar') + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None, jsc=None): + gateway=None, jsc=None, profiler_cls=BasicProfiler): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -88,6 +93,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. + :param jsc: The JavaSparkContext instance (optional). + :param profiler_cls: A class of custom Profiler used to do profiling + (default is pyspark.profiler.BasicProfiler). >>> from pyspark.context import SparkContext @@ -102,14 +110,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc) + conf, jsc, profiler_cls) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc): + conf, jsc, profiler_cls): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -128,7 +136,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if sparkHome: self._conf.setSparkHome(sparkHome) if environment: - for key, value in environment.iteritems(): + for key, value in environment.items(): self._conf.setExecutorEnv(key, value) for key, value in DEFAULT_CONFIGS.items(): self._conf.setIfMissing(key, value) @@ -148,6 +156,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + # disable randomness of hash of string in worker, if this is not + # launched by spark-submit + self.environment["PYTHONHASHSEED"] = "0" # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) @@ -182,17 +194,22 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) self._temp_dir = \ - self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \ + .getAbsolutePath() # profiling stats collected for each PythonRDD - self._profile_stats = [] + if self._conf.get("spark.python.profile", "false") == "true": + dump_path = self._conf.get("spark.python.profile.dump", None) + self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + else: + self.profiler_collector = None def _initialize_context(self, jconf): """ @@ -210,7 +227,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if (SparkContext._active_spark_context and @@ -229,6 +245,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "not in code that it run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax. @@ -306,7 +330,7 @@ def parallelize(self, c, numSlices=None): start0 = c[0] def getStart(split): - return start0 + (split * size / numSlices) * step + return start0 + int((split * size / numSlices)) * step def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) @@ -340,6 +364,7 @@ def pickleFile(self, name, minPartitions=None): minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.objectFile(name, minPartitions), self) + @ignore_unicode_prefix def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all @@ -352,7 +377,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello world!") + ... _ = testFile.write("Hello world!") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello world!'] @@ -361,6 +386,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode)) + @ignore_unicode_prefix def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system @@ -394,9 +420,9 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: - ... file1.write("1") + ... _ = file1.write("1") >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: - ... file2.write("2") + ... _ = file2.write("2") >>> textFiles = sc.wholeTextFiles(dirPath) >>> sorted(textFiles.collect()) [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] @@ -439,7 +465,7 @@ def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: d = {} - for k, v in d.iteritems(): + for k, v in d.items(): jm[k] = v return jm @@ -591,6 +617,7 @@ def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) + @ignore_unicode_prefix def union(self, rdds): """ Build the union of a list of RDDs. @@ -601,7 +628,7 @@ def union(self, rdds): >>> path = os.path.join(tempdir, "union-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello") + ... _ = testFile.write("Hello") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello'] @@ -614,7 +641,6 @@ def union(self, rdds): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -642,7 +668,7 @@ def accumulator(self, value, accum_param=None): elif isinstance(value, complex): accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM else: - raise Exception("No default accumulator param for type %s" % type(value)) + raise TypeError("No default accumulator param for type %s" % type(value)) SparkContext._next_accum_id += 1 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) @@ -660,7 +686,7 @@ def addFile(self, path): >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: - ... testFile.write("100") + ... _ = testFile.write("100") >>> sc.addFile(path) >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: @@ -688,11 +714,13 @@ def addPyFile(self, path): """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - - if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() def setCheckpointDir(self, dirName): """ @@ -727,7 +755,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): The application can use L{SparkContext.cancelJobGroup} to cancel all running jobs in this group. - >>> import thread, threading + >>> import threading >>> from time import sleep >>> result = "Not Set" >>> lock = threading.Lock() @@ -746,10 +774,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") >>> supress = lock.acquire() - >>> supress = thread.start_new_thread(start_job, (10,)) - >>> supress = thread.start_new_thread(stop_job, tuple()) + >>> supress = threading.Thread(target=start_job, args=(10,)).start() + >>> supress = threading.Thread(target=stop_job).start() >>> supress = lock.acquire() - >>> print result + >>> print(result) Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result @@ -792,6 +820,12 @@ def cancelAllJobs(self): """ self._jsc.sc().cancelAllJobs() + def statusTracker(self): + """ + Return :class:`StatusTracker` object + """ + return StatusTracker(self._jsc.statusTracker()) + def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ Executes the given partitionFunc on the specified set of partitions, @@ -809,48 +843,23 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) - return list(mappedRDD._collect_iterator_through_file(it)) - - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, + allowLocal) + return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 - stats.sort_stats("time", "cumulative").print_stats() - # mark it as showed - self._profile_stats[i][2] = True + self.profiler_collector.show_profiles() def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] + self.profiler_collector.dump_profiles(path) def _test(): diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index f09587f21170..7f06d4288c87 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -24,9 +24,10 @@ import traceback import time import gc -from errno import EINTR, ECHILD, EAGAIN +from errno import EINTR, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT + from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -53,29 +54,21 @@ def worker(sock): # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) exit_code = 0 try: worker_main(infile, outfile) except SystemExit as exc: exit_code = compute_real_exit_code(exc.code) finally: - outfile.flush() + try: + outfile.flush() + except Exception: + pass return exit_code -# Cleanup zombie children -def cleanup_dead_children(): - try: - while True: - pid, _ = os.waitpid(0, os.WNOHANG) - if not pid: - break - except: - pass - - def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -85,8 +78,12 @@ def manager(): listen_sock.bind(('127.0.0.1', 0)) listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() - write_int(listen_port, sys.stdout) - sys.stdout.flush() + + # re-open stdin/stdout in 'wb' mode + stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) + stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) + write_int(listen_port, stdout_bin) + stdout_bin.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -98,6 +95,7 @@ def handle_sigterm(*args): shutdown(1) signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + signal.signal(SIGCHLD, SIG_IGN) reuse = os.environ.get("SPARK_REUSE_WORKER") @@ -112,12 +110,9 @@ def handle_sigterm(*args): else: raise - # cleanup in signal handler will cause deadlock - cleanup_dead_children() - if 0 in ready_fds: try: - worker_pid = read_int(sys.stdin) + worker_pid = read_int(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) @@ -142,7 +137,7 @@ def handle_sigterm(*args): time.sleep(1) pid = os.fork() # error here will shutdown daemon else: - outfile = sock.makefile('w') + outfile = sock.makefile(mode='wb') write_int(e.errno, outfile) # Signal that the fork failed outfile.flush() outfile.close() @@ -154,7 +149,7 @@ def handle_sigterm(*args): listen_sock.close() try: # Acknowledge that the fork was successful - outfile = sock.makefile("w") + outfile = sock.makefile(mode="wb") write_int(os.getpid(), outfile) outfile.flush() outfile.close() diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index bc441f138f7f..4ef2afe03544 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False): if key is None: for order, it in enumerate(map(iter, iterables)): try: - next = it.next - h_append([next(), order * direction, next]) + h_append([next(it), order * direction, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + value, order, it = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted + s[0] = next(it) # raises StopIteration when exhausted _heapreplace(h, s) # restore heap condition except StopIteration: _heappop(h) # remove empty iterator if h: # fast case when only a single iterator remains - value, order, next = h[0] + value, order, it = h[0] yield value - for value in next.__self__: + for value in it: yield value return for order, it in enumerate(map(iter, iterables)): try: - next = it.next - value = next() - h_append([key(value), order * direction, value, next]) + value = next(it) + h_append([key(value), order * direction, value, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - key_value, order, value, next = s = h[0] + key_value, order, value, it = s = h[0] yield value - value = next() + value = next(it) s[0] = key(value) s[2] = value _heapreplace(h, s) except StopIteration: _heappop(h) if h: - key_value, order, value, next = h[0] + key_value, order, value, it = h[0] yield value - for value in next.__self__: + for value in it: yield value diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78..3cee4ea6e3a3 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -18,59 +18,75 @@ import atexit import os import sys +import select import signal import shlex +import socket import platform from subprocess import Popen, PIPE -from threading import Thread + +if sys.version >= '3': + xrange = range + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter +from pyspark.serializers import read_int -def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - gateway_port = -1 +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, (list, tuple, xrange)) + +ListConverter.can_convert = can_convert_list + + +def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" - submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") - submit_args = submit_args if submit_args is not None else "" - submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] + submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) + + # Start a socket that will be used by PythonGatewayServer to communicate its port to us + callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + callback_socket.bind(('127.0.0.1', 0)) + callback_socket.listen(1) + callback_host, callback_port = callback_socket.getsockname() + env = dict(os.environ) + env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host + env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - env = dict(os.environ) - env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows - proc = Popen(command, stdout=PIPE, stdin=PIPE) - - try: - # Determine which ephemeral port the server started on: - gateway_port = proc.stdout.readline() - gateway_port = int(gateway_port) - except ValueError: - # Grab the remaining lines of stdout - (stdout, _) = proc.communicate() - exit_code = proc.poll() - error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n" - error_msg += "Warning: Expected GatewayServer to output a port, but found " - if gateway_port == "" and stdout == "": - error_msg += "no output.\n" - else: - error_msg += "the following:\n\n" - error_msg += "--------------------------------------------------------------\n" - error_msg += gateway_port + stdout - error_msg += "--------------------------------------------------------------\n" - raise Exception(error_msg) + proc = Popen(command, stdin=PIPE, env=env) + + gateway_port = None + # We use select() here in order to avoid blocking indefinitely if the subprocess dies + # before connecting + while gateway_port is None and proc.poll() is None: + timeout = 1 # (seconds) + readable, _, _ = select.select([callback_socket], [], [], timeout) + if callback_socket in readable: + gateway_connection = callback_socket.accept()[0] + # Determine which ephemeral port the server started on: + gateway_port = read_int(gateway_connection.makefile(mode="rb")) + gateway_connection.close() + callback_socket.close() + if gateway_port is None: + raise Exception("Java gateway process exited before sending the driver its port number") # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -88,33 +104,17 @@ def killChild(): Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)]) atexit.register(killChild) - # Create a thread to echo output from the GatewayServer, which is required - # for Java log output to show up: - class EchoOutputThread(Thread): - - def __init__(self, stream): - Thread.__init__(self) - self.daemon = True - self.stream = stream - - def run(self): - while True: - line = self.stream.readline() - sys.stderr.write(line) - EchoOutputThread(proc.stdout).start() - # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b4a844713745..94df3990164d 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -32,11 +32,12 @@ """ from pyspark.resultiterable import ResultIterable +from functools import reduce def _do_python_join(rdd, other, numPartitions, dispatch): - vs = rdd.map(lambda (k, v): (k, (1, v))) - ws = other.map(lambda (k, v): (k, (2, v))) + vs = rdd.mapValues(lambda v: (1, v)) + ws = other.mapValues(lambda v: (2, v)) return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__())) @@ -48,7 +49,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +63,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +77,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -98,14 +99,15 @@ def dispatch(seq): def python_cogroup(rdds, numPartitions): def make_mapper(i): - return lambda (k, v): (k, (i, v)) - vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)] + return lambda v: (i, v) + vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)] union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds) rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py new file mode 100644 index 000000000000..47fed80f42e1 --- /dev/null +++ b/python/pyspark/ml/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param import * +from pyspark.ml.pipeline import * + +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py new file mode 100644 index 000000000000..45754bc9d4b1 --- /dev/null +++ b/python/pyspark/ml/classification.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ + HasRegParam +from pyspark.mllib.common import inherit_doc + + +__all__ = ['LogisticRegression', 'LogisticRegressionModel'] + + +@inherit_doc +class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + HasRegParam): + """ + Logistic regression. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sc.parallelize([ + ... Row(label=1.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> model = lr.fit(df) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction + 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + """ + super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + Sets params for logistic regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return LogisticRegressionModel(java_model) + + +class LogisticRegressionModel(JavaModel): + """ + Model fitted by LogisticRegression. + """ + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py new file mode 100644 index 000000000000..4e4614b859ac --- /dev/null +++ b/python/pyspark/ml/feature.py @@ -0,0 +1,130 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaTransformer +from pyspark.mllib.common import inherit_doc + +__all__ = ['Tokenizer', 'HashingTF'] + + +@inherit_doc +@ignore_unicode_prefix +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + A tokenizer that converts the input string to lowercase and then + splits it by white spaces. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(text="a b c")]).toDF() + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> tokenizer.transform(df).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> # Change a parameter. + >>> tokenizer.setParams(outputCol="tokens").transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> tokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + _java_class = "org.apache.spark.ml.feature.Tokenizer" + + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(Tokenizer, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +@inherit_doc +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): + """ + Maps a sequence of terms to their term frequencies using the + hashing trick. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> hashingTF.transform(df).head().features + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} + >>> hashingTF.transform(df, params).head().vector + SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) + """ + + _java_class = "org.apache.spark.ml.feature.HashingTF" + + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) + """ + super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) + Sets params for this HashingTF. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py new file mode 100644 index 000000000000..49c20b4cf70c --- /dev/null +++ b/python/pyspark/ml/param/__init__.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark.ml.util import Identifiable + + +__all__ = ['Param', 'Params'] + + +class Param(object): + """ + A param with self-contained documentation. + """ + + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise TypeError("Parent must be a Params but got type %s." % type(parent)) + self.parent = parent + self.name = str(name) + self.doc = str(doc) + + def __str__(self): + return str(self.parent) + "__" + self.name + + def __repr__(self): + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) + + +class Params(Identifiable): + """ + Components that take parameters. This also provides an internal + param map to store parameter values attached to the instance. + """ + + __metaclass__ = ABCMeta + + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} + + @property + def params(self): + """ + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type + :py:class:`Param`. + """ + return list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) + + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. + """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) + return paramMap + + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + + @staticmethod + def _dummy(): + """ + Returns a dummy Params instance used as a placeholder to generate docs. + """ + dummy = Params() + dummy.uid = "undefined" + return dummy + + def _set(self, **kwargs): + """ + Sets user-supplied params. + """ + for param, value in kwargs.items(): + self.paramMap[getattr(self, param)] = value + return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.items(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py new file mode 100644 index 000000000000..6a3192465d66 --- /dev/null +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +header = """# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#""" + +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py + + +def _gen_param_code(name, doc, defaultValueStr): + """ + Generates Python code for a shared param class. + + :param name: param name + :param doc: param doc + :param defaultValueStr: string representation of the default value + :return: code string + """ + # TODO: How to correctly inherit instance attributes? + template = '''class Has$Name(Params): + """ + Mixin for param $name: $doc. + """ + + # a placeholder to make it appear in the generated doc + $name = Param(Params._dummy(), "$name", "$doc") + + def __init__(self): + super(Has$Name, self).__init__() + #: param for $doc + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) + + def set$Name(self, value): + """ + Sets the value of :py:attr:`$name`. + """ + self.paramMap[self.$name] = value + return self + + def get$Name(self): + """ + Gets the value of $name or its default value. + """ + return self.getOrDefault(self.$name)''' + + Name = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", Name) \ + .replace("$doc", doc) \ + .replace("$defaultValueStr", str(defaultValueStr)) + +if __name__ == "__main__": + print(header) + print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n") + print("from pyspark.ml.param import Param, Params\n\n") + shared = [ + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), + ("featuresCol", "features column name", "'features'"), + ("labelCol", "label column name", "'label'"), + ("predictionCol", "prediction column name", "'prediction'"), + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] + code = [] + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) + print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py new file mode 100644 index 000000000000..13b6749998ad --- /dev/null +++ b/python/pyspark/ml/param/shared.py @@ -0,0 +1,252 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. + +from pyspark.ml.param import Param, Params + + +class HasMaxIter(Params): + """ + Mixin for param maxIter: max number of iterations. + """ + + # a placeholder to make it appear in the generated doc + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") + + def __init__(self): + super(HasMaxIter, self).__init__() + #: param for max number of iterations + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) + + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + self.paramMap[self.maxIter] = value + return self + + def getMaxIter(self): + """ + Gets the value of maxIter or its default value. + """ + return self.getOrDefault(self.maxIter) + + +class HasRegParam(Params): + """ + Mixin for param regParam: regularization constant. + """ + + # a placeholder to make it appear in the generated doc + regParam = Param(Params._dummy(), "regParam", "regularization constant") + + def __init__(self): + super(HasRegParam, self).__init__() + #: param for regularization constant + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + self.paramMap[self.regParam] = value + return self + + def getRegParam(self): + """ + Gets the value of regParam or its default value. + """ + return self.getOrDefault(self.regParam) + + +class HasFeaturesCol(Params): + """ + Mixin for param featuresCol: features column name. + """ + + # a placeholder to make it appear in the generated doc + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") + + def __init__(self): + super(HasFeaturesCol, self).__init__() + #: param for features column name + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + self.paramMap[self.featuresCol] = value + return self + + def getFeaturesCol(self): + """ + Gets the value of featuresCol or its default value. + """ + return self.getOrDefault(self.featuresCol) + + +class HasLabelCol(Params): + """ + Mixin for param labelCol: label column name. + """ + + # a placeholder to make it appear in the generated doc + labelCol = Param(Params._dummy(), "labelCol", "label column name") + + def __init__(self): + super(HasLabelCol, self).__init__() + #: param for label column name + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + self.paramMap[self.labelCol] = value + return self + + def getLabelCol(self): + """ + Gets the value of labelCol or its default value. + """ + return self.getOrDefault(self.labelCol) + + +class HasPredictionCol(Params): + """ + Mixin for param predictionCol: prediction column name. + """ + + # a placeholder to make it appear in the generated doc + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") + + def __init__(self): + super(HasPredictionCol, self).__init__() + #: param for prediction column name + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + self.paramMap[self.predictionCol] = value + return self + + def getPredictionCol(self): + """ + Gets the value of predictionCol or its default value. + """ + return self.getOrDefault(self.predictionCol) + + +class HasInputCol(Params): + """ + Mixin for param inputCol: input column name. + """ + + # a placeholder to make it appear in the generated doc + inputCol = Param(Params._dummy(), "inputCol", "input column name") + + def __init__(self): + super(HasInputCol, self).__init__() + #: param for input column name + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) + + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + self.paramMap[self.inputCol] = value + return self + + def getInputCol(self): + """ + Gets the value of inputCol or its default value. + """ + return self.getOrDefault(self.inputCol) + + +class HasOutputCol(Params): + """ + Mixin for param outputCol: output column name. + """ + + # a placeholder to make it appear in the generated doc + outputCol = Param(Params._dummy(), "outputCol", "output column name") + + def __init__(self): + super(HasOutputCol, self).__init__() + #: param for output column name + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + self.paramMap[self.outputCol] = value + return self + + def getOutputCol(self): + """ + Gets the value of outputCol or its default value. + """ + return self.getOrDefault(self.outputCol) + + +class HasNumFeatures(Params): + """ + Mixin for param numFeatures: number of features. + """ + + # a placeholder to make it appear in the generated doc + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") + + def __init__(self): + super(HasNumFeatures, self).__init__() + #: param for number of features + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. + """ + self.paramMap[self.numFeatures] = value + return self + + def getNumFeatures(self): + """ + Gets the value of numFeatures or its default value. + """ + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py new file mode 100644 index 000000000000..7c1ec3026da6 --- /dev/null +++ b/python/pyspark/ml/pipeline.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark.ml.param import Param, Params +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc + + +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] + + +@inherit_doc +class Estimator(Params): + """ + Abstract class for estimators that fit models to data. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overwrites embedded + params + :returns: fitted model + """ + raise NotImplementedError() + + +@inherit_doc +class Transformer(Params): + """ + Abstract class for transformers that transform one dataset into + another. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overwrites embedded + params + :returns: transformed dataset + """ + raise NotImplementedError() + + +@inherit_doc +class Pipeline(Estimator): + """ + A simple pipeline, which acts as an estimator. A Pipeline consists + of a sequence of stages, each of which is either an + :py:class:`Estimator` or a :py:class:`Transformer`. When + :py:meth:`Pipeline.fit` is called, the stages are executed in + order. If a stage is an :py:class:`Estimator`, its + :py:meth:`Estimator.fit` method will be called on the input + dataset to fit a model. Then the model, which is a transformer, + will be used to transform the dataset as the input to the next + stage. If a stage is a :py:class:`Transformer`, its + :py:meth:`Transformer.transform` method will be called to produce + the dataset for the next stage. The fitted model from a + :py:class:`Pipeline` is an :py:class:`PipelineModel`, which + consists of fitted models and transformers, corresponding to the + pipeline stages. If there are no stages, the pipeline acts as an + identity transformer. + """ + + @keyword_only + def __init__(self, stages=[]): + """ + __init__(self, stages=[]) + """ + super(Pipeline, self).__init__() + #: Param for pipeline stages. + self.stages = Param(self, "stages", "pipeline stages") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + def setStages(self, value): + """ + Set pipeline stages. + :param value: a list of transformers or estimators + :return: the pipeline instance + """ + self.paramMap[self.stages] = value + return self + + def getStages(self): + """ + Get pipeline stages. + """ + if self.stages in self.paramMap: + return self.paramMap[self.stages] + + @keyword_only + def setParams(self, stages=[]): + """ + setParams(self, stages=[]) + Sets params for Pipeline. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def fit(self, dataset, params={}): + paramMap = self.extractParamMap(params) + stages = paramMap[self.stages] + for stage in stages: + if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): + raise TypeError( + "Cannot recognize a pipeline stage of type %s." % type(stage)) + indexOfLastEstimator = -1 + for i, stage in enumerate(stages): + if isinstance(stage, Estimator): + indexOfLastEstimator = i + transformers = [] + for i, stage in enumerate(stages): + if i <= indexOfLastEstimator: + if isinstance(stage, Transformer): + transformers.append(stage) + dataset = stage.transform(dataset, paramMap) + else: # must be an Estimator + model = stage.fit(dataset, paramMap) + transformers.append(model) + if i < indexOfLastEstimator: + dataset = model.transform(dataset, paramMap) + else: + transformers.append(stage) + return PipelineModel(transformers) + + +@inherit_doc +class PipelineModel(Transformer): + """ + Represents a compiled pipeline with transformers and fitted models. + """ + + def __init__(self, transformers): + super(PipelineModel, self).__init__() + self.transformers = transformers + + def transform(self, dataset, params={}): + paramMap = self.extractParamMap(params) + for t in self.transformers: + dataset = t.transform(dataset, paramMap) + return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py new file mode 100644 index 000000000000..3a42bcf72389 --- /dev/null +++ b/python/pyspark/ml/tests.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Spark ML Python APIs. +""" + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.sql import DataFrame +from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol +from pyspark.ml.pipeline import Transformer, Estimator, Pipeline + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class MockTransformer(Transformer): + + def __init__(self): + super(MockTransformer, self).__init__() + self.fake = Param(self, "fake", "fake") + self.dataset_index = None + self.fake_param_value = None + + def transform(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + dataset.index += 1 + return dataset + + +class MockEstimator(Estimator): + + def __init__(self): + super(MockEstimator, self).__init__() + self.fake = Param(self, "fake", "fake") + self.dataset_index = None + self.fake_param_value = None + self.model = None + + def fit(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + model = MockModel() + self.model = model + return model + + +class MockModel(MockTransformer, Transformer): + + def __init__(self): + super(MockModel, self).__init__() + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline() \ + .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + self.assertEqual(0, estimator0.dataset_index) + self.assertEqual(0, estimator0.fake_param_value) + model0 = estimator0.model + self.assertEqual(0, model0.dataset_index) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.fake_param_value) + self.assertEqual(2, estimator2.dataset_index) + model2 = estimator2.model + self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " + "not be called during fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. + """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py new file mode 100644 index 000000000000..d3cb100a9efa --- /dev/null +++ b/python/pyspark/ml/util.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from functools import wraps +import uuid + + +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper + + +class Identifiable(object): + """ + Object with a unique ID. + """ + + def __init__(self): + #: A unique id for the object. The default implementation + #: concatenates the class name, "_", and 8 random hex chars. + self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] + + def __repr__(self): + return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py new file mode 100644 index 000000000000..394f23c5e9b1 --- /dev/null +++ b/python/pyspark/ml/wrapper.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + + +@inherit_doc +class JavaWrapper(Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + + __metaclass__ = ABCMeta + + #: Fully-qualified class name of the wrapped Java component. + _java_class = None + + def _java_obj(self): + """ + Returns or creates a Java object. + """ + java_obj = _jvm() + for name in self._java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj() + + def _transfer_params_to_java(self, params, java_obj): + """ + Transforms the embedded params and additional params to the + input Java object. + :param params: additional params (overwriting embedded values) + :param java_obj: Java object to receive the params + """ + paramMap = self.extractParamMap(params) + for param in self.params: + if param in paramMap: + java_obj.set(param.name, paramMap[param]) + + def _empty_java_param_map(self): + """ + Returns an empty Java ParamMap reference. + """ + return _jvm().org.apache.spark.ml.param.ParamMap() + + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + + +@inherit_doc +class JavaEstimator(Estimator, JavaWrapper): + """ + Base class for :py:class:`Estimator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _create_model(self, java_model): + """ + Creates a model from the input Java model reference. + """ + return JavaModel(java_model) + + def _fit_java(self, dataset, params={}): + """ + Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: additional params (overwriting embedded values) + :return: fitted Java model + """ + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + + def fit(self, dataset, params={}): + java_model = self._fit_java(dataset, params) + return self._create_model(java_model) + + +@inherit_doc +class JavaTransformer(Transformer, JavaWrapper): + """ + Base class for :py:class:`Transformer`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def transform(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf, java_param_map), + dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def __init__(self, java_model): + super(JavaTransformer, self).__init__() + self._java_model = java_model + + def _java_obj(self): + return self._java_model diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index c3217620e3c4..07507b2ad0d0 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -18,18 +18,21 @@ """ Python bindings for MLlib. """ +from __future__ import absolute_import -# MLlib currently needs and NumPy 1.4+, so complain if lower +# MLlib currently needs NumPy 1.4+, so complain if lower import numpy if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys -import rand as random -random.__name__ = 'random' -random.RandomRDDs.__module__ = __name__ + '.random' -sys.modules[__name__ + '.random'] = random +from . import rand as random +modname = __name__ + '.random' +random.__name__ = modname +random.RandomRDDs.__module__ = modname +sys.modules[modname] = random +del modname, sys diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 00e2e76711e8..a70c664a71fd 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,22 +21,23 @@ from numpy import array from pyspark import RDD -from pyspark.mllib.common import callMLlibFunc -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py +from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.mllib.util import Saveable, Loader, inherit_doc __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LinearBinaryClassificationModel(LinearModel): +class LinearClassificationModel(LinearModel): """ - Represents a linear binary classification model that predicts to whether an - example is positive (1.0) or negative (0.0). + A private abstract class representing a multiclass classification model. + The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): - super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + super(LinearClassificationModel, self).__init__(weights, intercept) self._threshold = None def setThreshold(self, value): @@ -46,14 +47,26 @@ def setThreshold(self, value): Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater than or equal to this threshold is identified as an positive, and negative otherwise. + It is used for binary classification only. """ self._threshold = value + @property + def threshold(self): + """ + .. note:: Experimental + + Returns the threshold (if any) used for converting raw prediction scores + into 0/1 predictions. It is used for binary classification only. + """ + return self._threshold + def clearThreshold(self): """ .. note:: Experimental Clears the threshold so that `predict` will output raw prediction scores. + It is used for binary classification only. """ self._threshold = None @@ -65,7 +78,7 @@ def predict(self, test): raise NotImplementedError -class LogisticRegressionModel(LinearBinaryClassificationModel): +class LogisticRegressionModel(LinearClassificationModel): """A linear binary classification model derived from logistic regression. @@ -73,7 +86,7 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -82,7 +95,7 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): [1, 0] >>> lrm.clearThreshold() >>> lrm.predict([0.0, 1.0]) - 0.123... + 0.279... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), @@ -90,7 +103,7 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> lrm.predict(array([0.0, 1.0])) 1 >>> lrm.predict(array([1.0, 0.0])) @@ -99,10 +112,52 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): 1 >>> lrm.predict(SparseVector(2, {0: 1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LogisticRegressionModel.load(sc, path) + >>> sameModel.predict(array([0.0, 1.0])) + 1 + >>> sameModel.predict(SparseVector(2, {0: 1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass + >>> multi_class_data = [ + ... LabeledPoint(0.0, [0.0, 1.0, 0.0]), + ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), + ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) + ... ] + >>> data = sc.parallelize(multi_class_data) + >>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) + >>> mcm.predict([0.0, 0.5, 0.0]) + 0 + >>> mcm.predict([0.8, 0.0, 0.0]) + 1 + >>> mcm.predict([0.0, 0.0, 0.3]) + 2 """ - def __init__(self, weights, intercept): + def __init__(self, weights, intercept, numFeatures, numClasses): super(LogisticRegressionModel, self).__init__(weights, intercept) + self._numFeatures = int(numFeatures) + self._numClasses = int(numClasses) self._threshold = 0.5 + if self._numClasses == 2: + self._dataWithBiasSize = None + self._weightsMatrix = None + else: + self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1) + self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1, + self._dataWithBiasSize) + + @property + def numFeatures(self): + return self._numFeatures + + @property + def numClasses(self): + return self._numClasses def predict(self, x): """ @@ -113,23 +168,60 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) - margin = self.weights.dot(x) + self._intercept - if margin > 0: - prob = 1 / (1 + exp(-margin)) + if self.numClasses == 2: + margin = self.weights.dot(x) + self._intercept + if margin > 0: + prob = 1 / (1 + exp(-margin)) + else: + exp_margin = exp(margin) + prob = exp_margin / (1 + exp_margin) + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 else: - exp_margin = exp(margin) - prob = exp_margin / (1 + exp_margin) - if self._threshold is None: - return prob - else: - return 1 if prob > self._threshold else 0 + best_class = 0 + max_margin = 0.0 + if x.size + 1 == self._dataWithBiasSize: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i][0:x.size]) + \ + self._weightsMatrix[i][x.size] + if margin > max_margin: + max_margin = margin + best_class = i + 1 + else: + for i in range(0, self._numClasses - 1): + margin = x.dot(self._weightsMatrix[i]) + if margin > max_margin: + max_margin = margin + best_class = i + 1 + return best_class + + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( + _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + numFeatures = java_model.numFeatures() + numClasses = java_model.numClasses() + threshold = java_model.getThreshold().get() + model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) + model.setThreshold(threshold) + return model class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.01, regType="l2", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False, + validateData=True): """ Train a logistic regression model on the given data. @@ -155,11 +247,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -168,7 +263,7 @@ class LogisticRegressionWithLBFGS(object): @classmethod def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4): + intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -194,12 +289,17 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType update (default: 10). :param tolerance: The convergence tolerance of iterations for L-BFGS (default: 1e-4). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) + :param numClasses: The number of classes (i.e., outcomes) a label can take + in Multinomial Logistic Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -207,13 +307,21 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, - float(regParam), str(regType), bool(intercept), int(corrections), - float(tolerance)) - + float(regParam), regType, bool(intercept), int(corrections), + float(tolerance), bool(validateData), int(numClasses)) + + if initialWeights is None: + if numClasses == 2: + initialWeights = [0.0] * len(data.first().features) + else: + if intercept: + initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) + else: + initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearBinaryClassificationModel): +class SVMModel(LinearClassificationModel): """A support vector machine. @@ -223,14 +331,14 @@ class SVMModel(LinearBinaryClassificationModel): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(data)) + >>> svm = SVMWithSGD.train(sc.parallelize(data), iterations=10) >>> svm.predict([1.0]) 1 >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() >>> svm.predict(array([1.0])) - 1.25... + 1.44... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), @@ -238,11 +346,23 @@ class SVMModel(LinearBinaryClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) + >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> svm.predict(SparseVector(2, {1: 1.0})) 1 >>> svm.predict(SparseVector(2, {0: -1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> svm.save(sc, path) + >>> sameModel = SVMModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {1: 1.0})) + 1 + >>> sameModel.predict(SparseVector(2, {0: -1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) @@ -263,12 +383,29 @@ def predict(self, x): else: return 1 if margin > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = SVMModel(weights, intercept) + model.setThreshold(threshold) + return model + class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="l2", + intercept=False, validateData=True): """ Train a support vector machine on the given data. @@ -294,16 +431,20 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, or not of the augmented representation for training data (i.e. whether bias features are activated or not). + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept)) + bool(intercept), bool(validateData)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) -class NaiveBayesModel(object): +@inherit_doc +class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. @@ -334,6 +475,16 @@ class NaiveBayesModel(object): 0.0 >>> model.predict(SparseVector(2, {0: 1.0})) 1.0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = NaiveBayesModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) + True + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def __init__(self, labels, pi, theta): @@ -348,6 +499,24 @@ def predict(self, x): x = _convert_to_vector(x) return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] + def save(self, sc, path): + java_labels = _py2java(sc, self.labels.tolist()) + java_pi = _py2java(sc, self.pi.tolist()) + java_theta = _py2java(sc, self.theta.tolist()) + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel( + java_labels, java_pi, java_theta) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( + sc._jsc.sc(), path) + # Can not unpickle array.array from Pyrolite in Python3 with "bytes" + py_labels = _java2py(sc, java_model.labels(), "latin1") + py_pi = _java2py(sc, java_model.pi(), "latin1") + py_theta = _java2py(sc, java_model.theta(), "latin1") + return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta)) + class NaiveBayes(object): diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 6b713aa39374..abbb7cf60eec 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,19 +15,30 @@ # limitations under the License. # +import sys +import array as pyarray + +if sys.version > '3': + xrange = range + +from numpy import array + +from pyspark import RDD from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc +from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.util import Saveable, Loader, inherit_doc -__all__ = ['KMeansModel', 'KMeans'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] -class KMeansModel(object): +@inherit_doc +class KMeansModel(Saveable, Loader): """A clustering model derived from the k-means method. - >>> from numpy import array - >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) >>> model = KMeans.train( ... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0])) @@ -50,8 +61,18 @@ class KMeansModel(object): True >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3]) True - >>> type(model.clusterCenters) - + >>> isinstance(model.clusterCenters, list) + True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = KMeansModel.load(sc, path) + >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) + True + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def __init__(self, centers): @@ -74,6 +95,16 @@ def predict(self, x): best_distance = distance return best + def save(self, sc, path): + java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path) + return KMeansModel(_java2py(sc, java_model.clusterCenters())) + class KMeans(object): @@ -86,6 +117,87 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) +class GaussianMixtureModel(object): + + """A clustering model derived from the Gaussian Mixture Model method. + + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, + ... 0.9,0.8,0.75,0.935, + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, + ... maxIterations=50, seed=10) + >>> labels = model.predict(clusterdata_1).collect() + >>> labels[0]==labels[1] + False + >>> labels[1]==labels[2] + True + >>> labels[4]==labels[5] + True + >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, + ... maxIterations=150, seed=10) + >>> labels = model.predict(clusterdata_2).collect() + >>> labels[0]==labels[1]==labels[2] + True + >>> labels[3]==labels[4] + True + """ + + def __init__(self, weights, gaussians): + self.weights = weights + self.gaussians = gaussians + self.k = len(self.weights) + + def predict(self, x): + """ + Find the cluster to which the points in 'x' has maximum membership + in this model. + + :param x: RDD of data points. + :return: cluster_labels. RDD of cluster labels. + """ + if isinstance(x, RDD): + cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) + return cluster_labels + + def predictSoft(self, x): + """ + Find the membership of each point in 'x' to all mixture components. + + :param x: RDD of data points. + :return: membership_matrix. RDD of array of double values. + """ + if isinstance(x, RDD): + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) + membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), + _convert_to_vector(self.weights), means, sigmas) + return membership_matrix.map(lambda x: pyarray.array('d', x)) + + +class GaussianMixture(object): + """ + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. + + :param data: RDD of data points + :param k: Number of components + :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 + :param maxIterations: Number of iterations. Default to 100 + :param seed: Random Seed + """ + @classmethod + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + """Train a Gaussian Mixture clustering model.""" + weight, mu, sigma = callMLlibFunc("trainGaussianMixture", + rdd.map(_convert_to_vector), k, + convergenceTol, maxIterations, seed) + mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] + return GaussianMixtureModel(weight, mvg_obj) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 3c5ee66cd8b6..ba6058978880 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,6 +15,11 @@ # limitations under the License. # +import sys +if sys.version >= '3': + long = int + unicode = str + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -36,7 +41,7 @@ def _new_smart_decode(obj): if isinstance(obj, float): - s = unicode(obj) + s = str(obj) return _float_str_mapping.get(s, s) return _old_smart_decode(obj) @@ -70,19 +75,19 @@ def _py2java(sc, obj): obj = _to_java_object_rdd(obj) elif isinstance(obj, SparkContext): obj = obj._jsc - elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)): - obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, list): + obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, basestring)): + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): pass else: - bytes = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.SerDe.loads(bytes) + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(data) return obj -def _java2py(sc, r): +def _java2py(sc, r, encoding="bytes"): if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -102,8 +107,8 @@ def _java2py(sc, r): except Py4JJavaError: pass # not pickable - if isinstance(r, bytearray): - r = PickleSerializer().loads(str(r)) + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) return r @@ -134,3 +139,20 @@ def __del__(self): def call(self, name, *a): """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. + """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py new file mode 100644 index 000000000000..16cb49cc0cff --- /dev/null +++ b/python/pyspark/mllib/evaluation.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper +from pyspark.sql import SQLContext +from pyspark.sql.types import StructField, StructType, DoubleType + + +class BinaryClassificationMetrics(JavaModelWrapper): + """ + Evaluator for binary classification. + + >>> scoreAndLabels = sc.parallelize([ + ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) + >>> metrics = BinaryClassificationMetrics(scoreAndLabels) + >>> metrics.areaUnderROC() + 0.70... + >>> metrics.areaUnderPR() + 0.83... + >>> metrics.unpersist() + """ + + def __init__(self, scoreAndLabels): + """ + :param scoreAndLabels: an RDD of (score, label) pairs + """ + sc = scoreAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ + StructField("score", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics + java_model = java_class(df._jdf) + super(BinaryClassificationMetrics, self).__init__(java_model) + + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + """ + return self.call("areaUnderROC") + + def areaUnderPR(self): + """ + Computes the area under the precision-recall curve. + """ + return self.call("areaUnderPR") + + def unpersist(self): + """ + Unpersists intermediate RDDs used in the computation. + """ + self.call("unpersist") + + +def _test(): + import doctest + from pyspark import SparkContext + import pyspark.mllib.evaluation + globs = pyspark.mllib.evaluation.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 10df6288065b..1140539a24e9 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -23,12 +23,17 @@ import sys import warnings import random +import binascii +if sys.version >= '3': + basestring = str + unicode = str from py4j.protocol import Py4JJavaError -from pyspark import RDD, SparkContext +from pyspark import SparkContext +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -58,7 +63,8 @@ class Normalizer(VectorTransformer): For any 1 <= `p` < float('inf'), normalizes samples using sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm. - For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + For `p` = float('inf'), max(abs(vector)) will be used as norm for + normalization. >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) @@ -120,12 +126,33 @@ def transform(self, vector): """ Applies standardization transformation on a vector. + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + :param vector: Vector or RDD of Vector to be standardized. - :return: Standardized vector. If the variance of a column is zero, - it will return default `0.0` for the column with zero variance. + :return: Standardized vector. If the variance of a column is + zero, it will return default `0.0` for the column with + zero variance. """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ @@ -148,9 +175,10 @@ def __init__(self, withMean=False, withStd=True): """ :param withMean: False by default. Centers the data with mean before scaling. It will build a dense output, so this - does not work on sparse input and will raise an exception. - :param withStd: True by default. Scales the data to unit standard - deviation. + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") @@ -159,10 +187,11 @@ def __init__(self, withMean=False, withStd=True): def fit(self, dataset): """ - Computes the mean and variance and stores as a model to be used for later scaling. + Computes the mean and variance and stores as a model to be used + for later scaling. - :param data: The data used to compute the mean and variance to build - the transformation model. + :param data: The data used to compute the mean and variance + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -174,14 +203,15 @@ class HashingTF(object): """ .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. + Maps a sequence of terms to their term frequencies using the hashing + trick. Note: the terms must be hashable (can not be dict/set/list...). >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) - SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): """ @@ -195,8 +225,9 @@ def indexOf(self, term): def transform(self, document): """ - Transforms the input document (list of terms) to term frequency vectors, - or transform the RDD of document to RDD of term frequency vectors. + Transforms the input document (list of terms) to term frequency + vectors, or transform the RDD of document to RDD of term + frequency vectors. """ if isinstance(document, RDD): return document.map(self.transform) @@ -220,7 +251,12 @@ def transform(self, x): the terms which occur in fewer than `minDocFreq` documents will have an entry of 0. - :param x: an RDD of term frequency vectors or a term frequency vector + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param x: an RDD of term frequency vectors or a term frequency + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -229,6 +265,12 @@ def transform(self, x): x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) + def idf(self): + """ + Returns the current IDF vector. + """ + return self.call('idf') + class IDF(object): """ @@ -241,9 +283,9 @@ class IDF(object): of documents that contain term `t`. This implementation supports filtering out terms which do not appear - in a minimum number of documents (controlled by the variable `minDocFreq`). - For terms that are not in at least `minDocFreq` documents, the IDF is - found as 0, resulting in TF-IDFs of 0. + in a minimum number of documents (controlled by the variable + `minDocFreq`). For terms that are not in at least `minDocFreq` + documents, the IDF is found as 0, resulting in TF-IDFs of 0. >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), @@ -316,7 +358,14 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + +@ignore_unicode_prefix class Word2Vec(object): """ Word2Vec creates vector representation of words in a text corpus. @@ -325,20 +374,21 @@ class Word2Vec(object): The vector representation can be used as features in natural language processing and machine learning algorithms. - We used skip-gram model in our implementation and hierarchical softmax - method to train the model. The variable names in the implementation - matches the original C implementation. + We used skip-gram model in our implementation and hierarchical + softmax method to train the model. The variable names in the + implementation matches the original C implementation. - For original C implementation, see https://code.google.com/p/word2vec/ + For original C implementation, + see https://code.google.com/p/word2vec/ For research papers, see Efficient Estimation of Word Representations in Vector Space - and - Distributed Representations of Words and Phrases and their Compositionality. + and Distributed Representations of Words and Phrases and their + Compositionality. >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) - >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] @@ -356,7 +406,8 @@ def __init__(self): self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = random.randint(0, sys.maxint) + self.seed = random.randint(0, sys.maxsize) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -374,15 +425,16 @@ def setLearningRate(self, learningRate): def setNumPartitions(self, numPartitions): """ - Sets number of partitions (default: 1). Use a small number for accuracy. + Sets number of partitions (default: 1). Use a small number for + accuracy. """ self.numPartitions = numPartitions return self def setNumIterations(self, numIterations): """ - Sets number of iterations (default: 1), which should be smaller than or equal to number of - partitions. + Sets number of iterations (default: 1), which should be smaller + than or equal to number of partitions. """ self.numIterations = numIterations return self @@ -394,6 +446,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -405,7 +465,8 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), int(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 000000000000..d8df02bdbaba --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy +from numpy import array +from collections import namedtuple + +from pyspark import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +@ignore_unicode_prefix +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + """ + + def freqItemsets(self): + """ + Returns the frequent itemsets of this model. + """ + return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + :param data: The input data set, each element + contains a transaction. + :param minSupport: The minimal support level + (default: `0.3`). + :param numPartitions: The number of partitions used by parallel + FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): + """ + Represents an (items, freq) tuple. + """ + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 7f21190ed8c2..cc9a4cf8ba17 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -25,11 +25,17 @@ import sys import array -import copy_reg + +if sys.version >= '3': + basestring = str + xrange = range + import copyreg as copy_reg +else: + import copy_reg import numpy as np -from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType @@ -57,7 +63,7 @@ def fast_pickle_array(ar): def _convert_to_vector(l): if isinstance(l, Vector): return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple): + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" @@ -88,7 +94,7 @@ def _vector_size(v): """ if isinstance(v, Vector): return len(v) - elif type(v) in (array.array, list, tuple): + elif type(v) in (array.array, list, tuple, xrange): return len(v) elif type(v) == np.ndarray: if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): @@ -139,7 +145,7 @@ def serialize(self, obj): values = [float(v) for v in obj] return (1, None, None, values) else: - raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) def deserialize(self, datum): assert len(datum) == 4, \ @@ -152,6 +158,9 @@ def deserialize(self, datum): else: raise ValueError("do not recognize type %r" % tpe) + def simpleString(self): + return "vector" + class Vector(object): @@ -170,10 +179,27 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): - if isinstance(ar, basestring): + if isinstance(ar, bytes): ar = np.frombuffer(ar, dtype=np.float64) elif not isinstance(ar, np.ndarray): ar = np.array(ar, dtype=np.float64) @@ -289,6 +315,27 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __truediv__ = _delegate("__truediv__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rtruediv__ = _delegate("__rtruediv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ @@ -305,12 +352,12 @@ def __init__(self, size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. - >>> print SparseVector(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> SparseVector(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" @@ -322,8 +369,8 @@ def __init__(self, size, *args): self.indices = np.array([p[0] for p in pairs], dtype=np.int32) self.values = np.array([p[1] for p in pairs], dtype=np.float64) else: - if isinstance(args[0], basestring): - assert isinstance(args[1], str), "values should be string too" + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" if args[0]: self.indices = np.frombuffer(args[0], np.int32) self.values = np.frombuffer(args[1], np.float64) @@ -514,7 +561,7 @@ def __getitem__(self, index): inds = self.indices vals = self.values if not isinstance(index, int): - raise ValueError( + raise TypeError( "Indices must be of type integer, got type %s" % type(index)) if index < 0: index += self.size @@ -552,12 +599,12 @@ def sparse(size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. - >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ return SparseVector(size, *args) @@ -591,9 +638,10 @@ class Matrix(object): Represents a local matrix. """ - def __init__(self, numRows, numCols): + def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols + self.isTransposed = isTransposed def toArray(self): """ @@ -601,24 +649,30 @@ def toArray(self): """ raise NotImplementedError + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. + """ + if isinstance(array_like, bytes): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + class DenseMatrix(Matrix): """ Column-major dense matrix. """ - def __init__(self, numRows, numCols, values): - Matrix.__init__(self, numRows, numCols) - if isinstance(values, basestring): - values = np.frombuffer(values, dtype=np.float64) - elif not isinstance(values, np.ndarray): - values = np.array(values, dtype=np.float64) + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols - if values.dtype != np.float64: - values.astype(np.float64) self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) def toArray(self): """ @@ -629,13 +683,124 @@ def toArray(self): array([[ 0., 2.], [ 1., 3.]]) """ - return self.values.reshape((self.numRows, self.numCols), order='F') + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') + + def toSparse(self): + """Convert to SparseMatrix""" + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] + colCounts = np.bincount(indices // self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] + + def __eq__(self, other): + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) + + +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + self.isTransposed) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. + if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + def toDense(self): + densevals = np.ravel(self.toArray(), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: def __eq__(self, other): - return (isinstance(other, DenseMatrix) and - self.numRows == other.numRows and - self.numCols == other.numCols and - all(self.values == other.values)) + return np.all(self.toArray == other.toArray) class Matrices(object): @@ -646,6 +811,13 @@ def dense(numRows, numCols, values): """ return DenseMatrix(numRows, numCols, values) + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + def _test(): import doctest diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py index 20ee9d78bf5b..06fbc0eb6aef 100644 --- a/python/pyspark/mllib/rand.py +++ b/python/pyspark/mllib/rand.py @@ -88,10 +88,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): :param seed: Random seed (default: a random long integer). :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - 0.0) < 0.1 True >>> abs(stats.stdev() - 1.0) < 0.1 @@ -118,10 +118,10 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L) + >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> from math import sqrt @@ -145,10 +145,10 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -171,10 +171,10 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). >>> mean = 2.0 - >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -202,10 +202,10 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L) + >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> abs(stats.stdev() - expStd) < 0.5 @@ -254,7 +254,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -286,8 +286,8 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \ - 100, 100, seed=1L).collect()) + >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() + >>> mat = np.matrix(m) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 @@ -315,7 +315,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -345,7 +345,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No >>> import numpy as np >>> mean = 0.5 - >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -380,8 +380,7 @@ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed= >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \ - 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 97ec74eda0b7..4b7d17d64e94 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,11 +15,14 @@ # limitations under the License. # +import array from collections import namedtuple from pyspark import SparkContext from pyspark.rdd import RDD -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.util import JavaLoader, JavaSaveable +from pyspark.sql import DataFrame __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -39,7 +42,8 @@ def __reduce__(self): return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(JavaModelWrapper): +@inherit_doc +class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): """A matrix factorisation model trained by regularized alternating least-squares. @@ -49,17 +53,17 @@ class MatrixFactorizationModel(JavaModelWrapper): >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) - >>> model.predict(2,2) - 0.4473... + >>> model.predict(2, 2) + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1, seed=10) + >>> model = ALS.train(ratings, 2, seed=0) >>> model.predictAll(testset).collect() - [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] + [Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)] >>> model = ALS.train(ratings, 4, seed=10) >>> model.userFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_user = model.userFeatures().take(1)[0] >>> latents = first_user[1] @@ -67,7 +71,7 @@ class MatrixFactorizationModel(JavaModelWrapper): True >>> model.productFeatures().collect() - [(2, array('d', [...])), (1, array('d', [...]))] + [(1, array('d', [...])), (2, array('d', [...]))] >>> first_product = model.productFeatures().take(1)[0] >>> latents = first_product[1] @@ -75,12 +79,30 @@ class MatrixFactorizationModel(JavaModelWrapper): True >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) - 3.735... + >>> model.predict(2, 2) + 3.8... + + >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) + >>> model = ALS.train(df, 1, nonnegative=True, seed=10) + >>> model.predict(2, 2) + 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) - 0.4473... + >>> model.predict(2, 2) + 0.4... + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = MatrixFactorizationModel.load(sc, path) + >>> sameModel.predict(2, 2) + 0.4... + >>> sameModel.predictAll(testset).collect() + [Rating(... + >>> try: + ... os.removedirs(path) + ... except OSError: + ... pass """ def predict(self, user, product): return self._java_model.predict(int(user), int(product)) @@ -89,27 +111,40 @@ def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() assert len(first) == 2, "user_product should be RDD of (user, product)" - user_product = user_product.map(lambda (u, p): (int(u), int(p))) + user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) def userFeatures(self): - return self.call("getUserFeatures") + return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) def productFeatures(self): - return self.call("getProductFeatures") + return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + return MatrixFactorizationModel(wrapper) class ALS(object): @classmethod def _prepare(cls, ratings): - assert isinstance(ratings, RDD), "ratings should be RDD" + if isinstance(ratings, RDD): + pass + elif isinstance(ratings, DataFrame): + ratings = ratings.rdd + else: + raise TypeError("Ratings should be represented by either an RDD or a DataFrame, " + "but got %s." % type(ratings)) first = ratings.first() - if not isinstance(first, Rating): - if isinstance(first, (tuple, list)): - ratings = ratings.map(lambda x: Rating(*x)) - else: - raise ValueError("rating should be RDD of Rating or tuple/list") + if isinstance(first, Rating): + pass + elif isinstance(first, (tuple, list)): + ratings = ratings.map(lambda x: Rating(*x)) + else: + raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first)) return ratings @classmethod @@ -130,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp def _test(): import doctest import pyspark.mllib.recommendation + from pyspark.sql import SQLContext globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 210060140fd9..4bc6351bdf02 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,11 +18,14 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.util import Saveable, Loader -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', - 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] +__all__ = ['LabeledPoint', 'LinearModel', + 'LinearRegressionModel', 'LinearRegressionWithSGD', + 'RidgeRegressionModel', 'RidgeRegressionWithSGD', + 'LassoModel', 'LassoWithSGD'] class LabeledPoint(object): @@ -31,8 +34,11 @@ class LabeledPoint(object): The features and labels of a data point. :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, list, - pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) + :param features: Vector of features for this point (NumPy array, + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) + + Note: 'label' and 'features' are accessible as class attributes. """ def __init__(self, label, features): @@ -69,6 +75,7 @@ def __repr__(self): return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) +@inherit_doc class LinearRegressionModelBase(LinearModel): """A linear regression model. @@ -89,6 +96,7 @@ def predict(self, x): return self.weights.dot(x) + self.intercept +@inherit_doc class LinearRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit. @@ -100,44 +108,88 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=np.array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=np.array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LinearRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", + ... intercept=True, validateData=True) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LinearRegressionModel(weights, intercept) + return model # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): - raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) - initial_weights = initial_weights or [0.0] * len(data.first().features) - weights, intercept = train_func(data, _convert_to_vector(initial_weights)) - return modelClass(weights, intercept) + raise TypeError("data should be an RDD of LabeledPoint, but got %s" % type(first)) + if initial_weights is None: + initial_weights = [0.0] * len(data.first().features) + if (modelClass == LogisticRegressionModel): + weights, intercept, numFeatures, numClasses = train_func( + data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept, numFeatures, numClasses) + else: + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.0, regType=None, intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False, + validateData=True): """ Train a linear regression model on the given data. @@ -159,19 +211,23 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features - are activated or not). + are activated or not). (default: False) + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept)) + regType, bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) +@inherit_doc class LassoModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an @@ -184,40 +240,78 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LassoModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LassoModel(weights, intercept) + return model class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a Lasso regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) +@inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): """A linear regression model derived from a least-squares fit with an @@ -230,36 +324,74 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = RidgeRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = RidgeRegressionModel(weights, intercept) + return model class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a ridge regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) @@ -269,7 +401,7 @@ def _test(): from pyspark import SparkContext import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py new file mode 100644 index 000000000000..e3e128513e0d --- /dev/null +++ b/python/pyspark/mllib/stat/__init__.py @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. +""" + +from pyspark.mllib.stat._statistics import * +from pyspark.mllib.stat.distribution import MultivariateGaussian +from pyspark.mllib.stat.test import ChiSqTestResult + +__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", + "MultivariateGaussian"] diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat/_statistics.py similarity index 84% rename from python/pyspark/mllib/stat.py rename to python/pyspark/mllib/stat/_statistics.py index c8af777a8b00..b475be4b4d95 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,17 +15,14 @@ # limitations under the License. # -""" -Python package for statistical functions in MLlib. -""" - -from pyspark import RDD +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat.test import ChiSqTestResult -__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -41,7 +38,7 @@ def variance(self): return self.call("variance").toArray() def count(self): - return self.call("count") + return int(self.call("count")) def numNonzeros(self): return self.call("numNonzeros").toArray() @@ -52,53 +49,11 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() -class ChiSqTestResult(JavaModelWrapper): - """ - .. note:: Experimental - - Object containing the test results for the chi-squared hypothesis test. - """ - @property - def method(self): - """ - Name of the test method - """ - return self._java_model.method() - - @property - def pValue(self): - """ - The probability of obtaining a test statistic result at least as - extreme as the one that was actually observed, assuming that the - null hypothesis is true. - """ - return self._java_model.pValue() - - @property - def degreesOfFreedom(self): - """ - Returns the degree(s) of freedom of the hypothesis test. - Return type should be Number(e.g. Int, Double) or tuples of Numbers. - """ - return self._java_model.degreesOfFreedom() - - @property - def statistic(self): - """ - Test statistic. - """ - return self._java_model.statistic() - - @property - def nullHypothesis(self): - """ - Null hypothesis of the test. - """ - return self._java_model.nullHypothesis() - - def __str__(self): - return self._java_model.toString() + def normL2(self): + return self.call("normL2").toArray() class Statistics(object): @@ -123,7 +78,7 @@ def colStats(rdd): >>> cStats.variance() array([ 4., 13., 0., 25.]) >>> cStats.count() - 3L + 3 >>> cStats.numNonzeros() array([ 3., 2., 0., 3.]) >>> cStats.max() @@ -169,20 +124,20 @@ def corr(x, y=None, method=None): >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) >>> pearsonCorr = Statistics.corr(rdd) - >>> print str(pearsonCorr).replace('nan', 'NaN') + >>> print(str(pearsonCorr).replace('nan', 'NaN')) [[ 1. 0.05564149 NaN 0.40047142] [ 0.05564149 1. NaN 0.91359586] [ NaN NaN 1. NaN] [ 0.40047142 0.91359586 NaN 1. ]] >>> spearmanCorr = Statistics.corr(rdd, method="spearman") - >>> print str(spearmanCorr).replace('nan', 'NaN') + >>> print(str(spearmanCorr).replace('nan', 'NaN')) [[ 1. 0.10540926 NaN 0.4 ] [ 0.10540926 1. NaN 0.9486833 ] [ NaN NaN 1. NaN] [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") - ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... print("Method name as second argument without 'method=' shouldn't be allowed.") ... except TypeError: ... pass """ @@ -198,6 +153,7 @@ def corr(x, y=None, method=None): return callMLlibFunc("corr", x.map(float), y.map(float), method) @staticmethod + @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ .. note:: Experimental @@ -233,11 +189,11 @@ def chiSqTest(observed, expected=None): >>> from pyspark.mllib.linalg import Vectors, Matrices >>> observed = Vectors.dense([4, 6, 5]) >>> pearson = Statistics.chiSqTest(observed) - >>> print pearson.statistic + >>> print(pearson.statistic) 0.4 >>> pearson.degreesOfFreedom 2 - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.8187 >>> pearson.method u'pearson' @@ -247,12 +203,12 @@ def chiSqTest(observed, expected=None): >>> observed = Vectors.dense([21, 38, 43, 80]) >>> expected = Vectors.dense([3, 5, 7, 20]) >>> pearson = Statistics.chiSqTest(observed, expected) - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.0027 >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - >>> print round(chi.statistic, 4) + >>> print(round(chi.statistic, 4)) 21.9958 >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), @@ -263,9 +219,9 @@ def chiSqTest(observed, expected=None): ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] >>> rdd = sc.parallelize(data, 4) >>> chi = Statistics.chiSqTest(rdd) - >>> print chi[0].statistic + >>> print(chi[0].statistic) 0.75 - >>> print chi[1].statistic + >>> print(chi[1].statistic) 1.5 """ if isinstance(observed, RDD): diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py new file mode 100644 index 000000000000..46f7a1d2f277 --- /dev/null +++ b/python/pyspark/mllib/stat/distribution.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ['MultivariateGaussian'] + + +class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): + + """Represents a (mu, sigma) tuple + + >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0))) + >>> (m.mu, m.sigma.toArray()) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + >>> (m[0], m[1]) + (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) + """ diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py new file mode 100644 index 000000000000..762506e952b4 --- /dev/null +++ b/python/pyspark/mllib/stat/test.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper + + +__all__ = ["ChiSqTestResult"] + + +class ChiSqTestResult(JavaModelWrapper): + """ + .. note:: Experimental + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f48e3d6dacb4..1b008b93bc13 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,10 +19,12 @@ Fuller unit tests for Python MLlib. """ +import os import sys +import tempfile import array as pyarray -from numpy import array, array_equal +from numpy import array, array_equal, zeros from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -34,14 +36,18 @@ else: import unittest +from pyspark import SparkContext +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, Vectors, Matrices + DenseMatrix, SparseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec +from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False try: @@ -52,6 +58,12 @@ pass ser = PickleSerializer() +sc = SparkContext('local[4]', "MLlib tests") + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc def _squared_distance(a, b): @@ -61,16 +73,16 @@ def _squared_distance(a, b): return b.squared_distance(a) -class VectorTests(PySparkTestCase): +class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) + nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs))) + nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): @@ -129,11 +141,84 @@ def test_sparse_vector_indexing(self): self.assertEquals(sv[-1], 2) self.assertEquals(sv[-2], 0) self.assertEquals(sv[-4], 0) - for ind in [4, -5, 7.8]: + for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) - - -class ListTests(PySparkTestCase): + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEquals(sm1.numRows, 3) + self.assertEquals(sm1.numCols, 4) + self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEquals(sm1.numRows, smnew.numRows) + self.assertEquals(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEquals(sm1t.numRows, 3) + self.assertEquals(sm1t.numCols, 4) + self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + +class ListTests(MLlibTestCase): """ Test MLlib algorithms on plain lists, to make sure they're passed through @@ -167,9 +252,36 @@ def test_kmeans_deterministic(self): # TODO: Allow small numeric difference. self.assertTrue(array_equal(c1, c2)) + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=56) + labels = clusters.predict(data).collect() + self.assertEquals(labels[0], labels[1]) + self.assertEquals(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEquals(round(c1, 7), round(c2, 7)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -179,13 +291,15 @@ def test_classification(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LogisticRegressionWithSGD.train(rdd) + temp_dir = tempfile.mkdtemp() + + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - svm_model = SVMWithSGD.train(rdd) + svm_model = SVMWithSGD.train(rdd, iterations=10) self.assertTrue(svm_model.predict(features[0]) <= 0) self.assertTrue(svm_model.predict(features[1]) > 0) self.assertTrue(svm_model.predict(features[2]) <= 0) @@ -198,18 +312,52 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = \ - DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + os.removedirs(temp_dir) + except OSError: + pass + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -219,34 +367,55 @@ def test_regression(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LinearRegressionWithSGD.train(rdd) + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - lasso_model = LassoWithSGD.train(rdd) + lasso_model = LassoWithSGD.train(rdd, iterations=10) self.assertTrue(lasso_model.predict(features[0]) <= 0) self.assertTrue(lasso_model.predict(features[1]) > 0) self.assertTrue(lasso_model.predict(features[2]) <= 0) self.assertTrue(lasso_model.predict(features[3]) > 0) - rr_model = RidgeRegressionWithSGD.train(rdd) + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(rr_model.predict(features[0]) <= 0) self.assertTrue(rr_model.predict(features[1]) > 0) self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = \ - DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) - -class StatTests(PySparkTestCase): + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + except ValueError: + self.fail() + + +class StatTests(MLlibTestCase): # SPARK-4023 def test_col_with_different_rdds(self): # numpy @@ -262,8 +431,21 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) -class VectorUDTTests(PySparkTestCase): + +class VectorUDTTests(MLlibTestCase): dv0 = DenseVector([]) dv1 = DenseVector([1.0, 2.0]) @@ -281,11 +463,11 @@ def test_serialization(self): def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = rdd.toDF() + schema = df.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) - vectors = srdd.map(lambda p: p.features).collect() + vectors = df.map(lambda p: p.features).collect() self.assertEqual(len(vectors), 2) for v in vectors: if isinstance(v, SparseVector): @@ -293,11 +475,11 @@ def test_infer_schema(self): elif isinstance(v, DenseVector): self.assertEqual(v, self.dv1) else: - raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) @unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): +class SciPyTests(MLlibTestCase): """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, @@ -438,7 +620,7 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) -class ChiSqTestTests(PySparkTestCase): +class ChiSqTestTests(MLlibTestCase): def test_goodness_of_fit(self): from numpy import inf @@ -535,9 +717,82 @@ def test_right_number_of_results(self): self.assertEqual(len(chi), num_cols) self.assertIsNotNone(chi[1000]) + +class SerDeTest(MLlibTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +class FeatureTest(MLlibTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(MLlibTestCase): + def test_word2vec_setters(self): + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) + + +class StandardScalerTests(MLlibTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") + sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 66702478474d..cfcbea573fd2 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -20,25 +20,67 @@ import random from pyspark import SparkContext, RDD -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable -__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', + 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees'] -class DecisionTreeModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in ensemble. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the + ensemble. + """ + return self.call("totalNumNodes") + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + + +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - A decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + A decision tree model for classification or regression. """ def predict(self, x): """ Predict the label of one or more examples. + Note: In Python, predict cannot currently be used within an RDD + transformation or action. + Call predict directly on the RDD instead. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ @@ -62,14 +104,17 @@ def toDebugString(self): """ full model. """ return self._java_model.toDebugString() + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.DecisionTreeModel" -class DecisionTree(object): +class DecisionTree(object): """ - Learning algorithm for a decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a decision tree model for classification or + regression. """ @classmethod @@ -118,14 +163,16 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) - >>> print model, # it already has newline + >>> print(model) DecisionTreeModel classifier of depth 1 with 3 nodes - >>> print model.toDebugString(), # it already has newline + + >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.0) Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 + >>> model.predict(array([1.0])) 1.0 >>> model.predict(array([0.0])) @@ -146,17 +193,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, :param data: Training data: RDD of LabeledPoint. Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. + :param categoricalFeaturesInfo: Map from categorical feature + index to number of categories. + Any feature not in this map is treated as continuous. :param impurity: Supported values: "variance" :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each + node. + :param minInstancesPerNode: Min number of instances required at + child nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -186,51 +233,25 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(JavaModelWrapper): +@inherit_doc +class RandomForestModel(TreeEnsembleModel, JavaLoader): """ - Represents a random forest model. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Represents a random forest model. """ - def predict(self, x): - """ - Predict values for a single data point or an RDD of points using - the model trained. - """ - if isinstance(x, RDD): - return self.call("predict", x.map(_convert_to_vector)) - - else: - return self.call("predict", _convert_to_vector(x)) - def numTrees(self): - """ - Get number of trees in forest. - """ - return self.call("numTrees") - - def totalNumNodes(self): - """ - Get total number of nodes, summed over all trees in the forest. - """ - return self.call("totalNumNodes") - - def __repr__(self): - """ Summary of model """ - return self._java_model.toString() - - def toDebugString(self): - """ Full model """ - return self._java_model.toDebugString() + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.RandomForestModel" class RandomForest(object): """ - Learning algorithm for a random forest model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a random forest model for classification or + regression. """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -257,26 +278,30 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Method to train a decision tree model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels should take - values {0, 1, ..., numClasses-1}. + :param data: Training dataset: RDD of LabeledPoint. Labels + should take values {0, 1, ..., numClasses-1}. :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical features. - E.g., an entry (n -> k) indicates that feature n is categorical - with k categories indexed from 0: {0, 1, ..., k-1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that + feature n is categorical with k categories indexed + from 0: {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for splits at - each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; - depth 1 means 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + :param maxDepth: Maximum depth of the tree. + E.g., depth 0 means 1 leaf node; depth 1 means + 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features + (default: 100) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -295,9 +320,10 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 3 >>> model.totalNumNodes() 7 - >>> print model, + >>> print(model) TreeEnsembleModel classifier with 3 trees - >>> print model.toDebugString(), + + >>> print(model.toDebugString()) TreeEnsembleModel classifier with 3 trees Tree 0: @@ -312,6 +338,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Predict: 0.0 Else (feature 0 > 1.0) Predict: 1.0 + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) @@ -338,19 +365,21 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt {0, 1, ..., k-1}. :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 - leaf node; depth 1 means 1 internal node + 2 leaf nodes. - (default: 4) - :param maxBins: maximum number of bins used for splitting features - (default: 100) - :param seed: Random seed for bootstrapping and choosing feature subsets. + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + :param impurity: Criterion used for information gain + calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting + features (default: 100) + :param seed: Random seed for bootstrapping and choosing feature + subsets. :return: RandomForestModel that can be used for prediction Example usage: @@ -383,6 +412,153 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) +@inherit_doc +class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): + """ + .. note:: Experimental + + Represents a gradient-boosted tree model. + """ + + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + + +class GradientBoostedTrees(object): + """ + .. note:: Experimental + + Learning algorithm for a gradient boosted trees model for + classification or regression. + """ + + @classmethod + def _train(cls, data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + return GradientBoostedTreesModel(model) + + @classmethod + def trainClassifier(cls, data, categoricalFeaturesInfo, + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for + classification. + + :param data: Training dataset: RDD of LabeledPoint. + Labels should take values {0, 1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10) + >>> model.numTrees() + 10 + >>> model.totalNumNodes() + 30 + >>> print(model) # it already has newline + TreeEnsembleModel classifier with 10 trees + + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[2.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient + boosting. Supported: {"logLoss" (default), + "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the + contribution of each estimator. The learning rate + should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means + 1 leaf node; depth 1 means 1 internal node + 2 leaf + nodes. (default: 3) + :return: GradientBoostedTreesModel that can be used for + prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> data = sc.parallelize(sparse_data) + >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10) + >>> model.numTrees() + 10 + >>> model.totalNumNodes() + 12 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 4ed978b45409..16a90db146ef 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -15,12 +15,15 @@ # limitations under the License. # +import sys import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc +if sys.version > '3': + xrange = range + +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint class MLUtils(object): @@ -50,6 +53,7 @@ def _parse_libsvm_line(line, multiclass=None): @staticmethod def _convert_labeled_point_to_libsvm(p): """Converts a LabeledPoint to a string in LIBSVM format.""" + from pyspark.mllib.regression import LabeledPoint assert isinstance(p, LabeledPoint) items = [str(p.label)] v = _convert_to_vector(p.features) @@ -92,24 +96,20 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() - >>> type(examples[0]) == LabeledPoint - True - >>> print examples[0] - (1.0,(6,[0,2,4],[1.0,2.0,3.0])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (-1.0,(6,[],[])) - >>> type(examples[2]) == LabeledPoint - True - >>> print examples[2] - (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) + >>> examples[0] + LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0])) + >>> examples[1] + LabeledPoint(-1.0, (6,[],[])) + >>> examples[2] + LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0])) """ + from pyspark.mllib.regression import LabeledPoint if multiclass is not None: warnings.warn("deprecated", DeprecationWarning) @@ -130,6 +130,7 @@ def saveAsLibSVMFile(data, dir): >>> from tempfile import NamedTemporaryFile >>> from fileinput import input + >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ @@ -156,6 +157,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) @@ -168,6 +170,93 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) +class Saveable(object): + """ + Mixin for models and transformers which may be saved as files. + """ + + def save(self, sc, path): + """ + Save this model to the given path. + + This saves: + * human-readable (JSON) model metadata to path/metadata/ + * Parquet formatted data to path/data/ + + The model may be loaded using py:meth:`Loader.load`. + + :param sc: Spark context used to save model data. + :param path: Path specifying the directory in which to save + this model. If the directory already exists, + this method throws an exception. + """ + raise NotImplementedError + + +@inherit_doc +class JavaSaveable(Saveable): + """ + Mixin for models that provide save() through their Scala + implementation. + """ + + def save(self, sc, path): + self._java_model.save(sc._jsc.sc(), path) + + +class Loader(object): + """ + Mixin for classes which can load saved models from files. + """ + + @classmethod + def load(cls, sc, path): + """ + Load a model from the given path. The model should have been + saved using py:meth:`Saveable.save`. + + :param sc: Spark context used for loading model files. + :param path: Path specifying the directory to which the model + was saved. + :return: model instance + """ + raise NotImplemented + + +@inherit_doc +class JavaLoader(Loader): + """ + Mixin for classes which can load saved models using its Scala + implementation. + """ + + @classmethod + def _java_loader_class(cls): + """ + Returns the full class name of the Java loader. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = cls.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, cls.__name__]) + + @classmethod + def _load_java(cls, sc, path): + """ + Load a Java model from the given path. + """ + java_class = cls._java_loader_class() + java_obj = sc._jvm + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj.load(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = cls._load_java(sc, path) + return cls(java_model) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py new file mode 100644 index 000000000000..d18daaabfcb3 --- /dev/null +++ b/python/pyspark/profiler.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cProfile +import pstats +import os +import atexit + +from pyspark.accumulators import AccumulatorParam + + +class ProfilerCollector(object): + """ + This class keeps track of different profilers on a per + stage basis. Also this is used to create new profilers for + the different stages. + """ + + def __init__(self, profiler_cls, dump_path=None): + self.profiler_cls = profiler_cls + self.profile_dump_path = dump_path + self.profilers = [] + + def new_profiler(self, ctx): + """ Create a new profiler using class `profiler_cls` """ + return self.profiler_cls(ctx) + + def add_profiler(self, id, profiler): + """ Add a profiler for RDD `id` """ + if not self.profilers: + if self.profile_dump_path: + atexit.register(self.dump_profiles, self.profile_dump_path) + else: + atexit.register(self.show_profiles) + + self.profilers.append([id, profiler, False]) + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` """ + for id, profiler, _ in self.profilers: + profiler.dump(id, path) + self.profilers = [] + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, profiler, showed) in enumerate(self.profilers): + if not showed and profiler: + profiler.show(id) + # mark it as showed + self.profilers[i][2] = True + + +class Profiler(object): + """ + .. note:: DeveloperApi + + PySpark supports custom profilers, this is to allow for different profilers to + be used as well as outputting to different formats than what is provided in the + BasicProfiler. + + A custom profiler has to define or inherit the following methods: + profile - will produce a system profile of some sort. + stats - return the collected stats. + dump - dumps the profiles to a path + add - adds a profile to the existing accumulated profile + + The profiler class is chosen when creating a SparkContext + + >>> from pyspark import SparkConf, SparkContext + >>> from pyspark import BasicProfiler + >>> class MyCustomProfiler(BasicProfiler): + ... def show(self, id): + ... print("My custom profiles for RDD:%s" % id) + ... + >>> conf = SparkConf().set("spark.python.profile", "true") + >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) + >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.show_profiles() + My custom profiles for RDD:1 + My custom profiles for RDD:2 + >>> sc.stop() + """ + + def __init__(self, ctx): + pass + + def profile(self, func): + """ Do profiling on the function `func`""" + raise NotImplemented + + def stats(self): + """ Return the collected profiling stats (pstats.Stats)""" + raise NotImplemented + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self.stats() + if stats: + print("=" * 60) + print("Profile of RDD" % id) + print("=" * 60) + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + +class BasicProfiler(Profiler): + """ + BasicProfiler is the default profiler, which is implemented based on + cProfile and Accumulator + """ + def __init__(self, ctx): + Profiler.__init__(self, ctx) + # Creates a new accumulator for combining the profiles of different + # partitions of a stage + self._accumulator = ctx.accumulator(None, PStatsParam) + + def profile(self, func): + """ Runs and profiles the method to_profile passed in. A profile object is returned. """ + pr = cProfile.Profile() + pr.runcall(func) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + # Adds a new profile to the existing accumulated value + self._accumulator.add(st) + + def stats(self): + return self._accumulator.value + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4977400ac1c0..d254deb527d1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -16,22 +16,29 @@ # import copy -from collections import defaultdict -from itertools import chain, ifilter, imap -import operator -import os import sys +import os +import re +import operator import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread import warnings import heapq import bisect import random -from math import sqrt, log, isinf, isnan +import socket +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread +from collections import defaultdict +from itertools import chain +from functools import reduce +from math import sqrt, log, isinf, isnan, pow, ceil + +if sys.version > '3': + basestring = unicode = str +else: + from itertools import imap as map, ifilter as filter -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -42,7 +49,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -51,20 +58,21 @@ __all__ = ["RDD"] -# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized -# hash for string def portable_hash(x): """ - This function returns consistant hash code for builtin types, especially + This function returns consistent hash code for builtin types, especially for None and tuple with None. - The algrithm is similar to that one used by CPython 2.7 + The algorithm is similar to that one used by CPython 2.7 >>> portable_hash(None) 0 >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") + if x is None: return 0 if isinstance(x, tuple): @@ -72,7 +80,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= sys.maxint + h &= sys.maxsize h ^= len(x) if h == -1: h = -2 @@ -112,6 +120,44 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +def _load_from_socket(port, serializer): + sock = socket.socket() + sock.settimeout(3) + try: + sock.connect(("localhost", port)) + rf = sock.makefile("rb", 65536) + for item in serializer.load_stream(rf): + yield item + finally: + sock.close() + + +def ignore_unicode_prefix(f): + """ + Ignore the 'u' prefix of string in doc tests, to make it works + in both python 2 and 3 + """ + if sys.version >= '3': + # the representation of unicode string in Python 3 does not have prefix 'u', + # so remove the prefix 'u' for doc tests + literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE) + f.__doc__ = literal_re.sub(r'\1\2', f.__doc__) + return f + + +class Partitioner(object): + def __init__(self, numPartitions, partitionFunc): + self.numPartitions = numPartitions + self.partitionFunc = partitionFunc + + def __eq__(self, other): + return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions + and self.partitionFunc == other.partitionFunc) + + def __call__(self, k): + return self.partitionFunc(k) % self.numPartitions + + class RDD(object): """ @@ -127,7 +173,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() - self._partitionFunc = None + self.partitioner = None def _pickled(self): return self._reserialize(AutoBatchedSerializer(PickleSerializer())) @@ -141,6 +187,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ @@ -216,7 +273,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -231,7 +288,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -294,7 +351,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -306,15 +363,21 @@ def distinct(self, numPartitions=None): """ return self.map(lambda x: (x, None)) \ .reduceByKey(lambda x, _: x, numPartitions) \ - .map(lambda (x, _): x) + .map(lambda x: x[0]) def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD. + :param withReplacement: can elements be sampled multiple times (replaced when sampled out) + :param fraction: expected size of the sample as a fraction of this RDD's size + without replacement: probability that each element is chosen; fraction must be [0, 1] + with replacement: expected number of times each element is chosen; fraction must be >= 0 + :param seed: seed for the random number generator + >>> rdd = sc.parallelize(range(100), 4) - >>> rdd.sample(False, 0.1, 81).count() - 10 + >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 + True """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -327,12 +390,14 @@ def randomSplit(self, weights, seed=None): :param seed: random seed :return: split RDDs in a list - >>> rdd = sc.parallelize(range(5), 1) + >>> rdd = sc.parallelize(range(500), 1) >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) - >>> rdd1.collect() - [1, 3] - >>> rdd2.collect() - [0, 2, 4] + >>> len(rdd1.collect() + rdd2.collect()) + 500 + >>> 150 < rdd1.count() < 250 + True + >>> 250 < rdd2.count() < 350 + True """ s = float(sum(weights)) cweights = [0.0] @@ -375,7 +440,7 @@ def takeSample(self, withReplacement, num, seed=None): rand.shuffle(samples) return samples - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) if num > maxSampleSize: raise ValueError( "Sample size cannot be greater than %d." % maxSampleSize) @@ -389,7 +454,7 @@ def takeSample(self, withReplacement, num, seed=None): # See: scala/spark/RDD.scala while len(samples) < num: # TODO: add log warning for when more than one iteration was run - seed = rand.randint(0, sys.maxint) + seed = rand.randint(0, sys.maxsize) samples = self.sample(withReplacement, fraction, seed).collect() rand.shuffle(samples) @@ -440,14 +505,17 @@ def union(self, other): if self._jrdd_deserializer == other._jrdd_deserializer: rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer) - return rdd else: # These RDDs contain data in different serialized formats, so we # must normalize them to the default serializer. self_copy = self._reserialize() other_copy = other._reserialize() - return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, - self.ctx.serializer) + rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + if (self.partitioner == other.partitioner and + self.getNumPartitions() == rdd.getNumPartitions()): + rdd.partitioner = self.partitioner + return rdd def intersection(self, other): """ @@ -463,7 +531,7 @@ def intersection(self, other): """ return self.map(lambda v: (v, None)) \ .cogroup(other.map(lambda v: (v, None))) \ - .filter(lambda (k, vs): all(vs)) \ + .filter(lambda k_vs: all(k_vs[1])) \ .keys() def _reserialize(self, serializer=None): @@ -505,7 +573,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -529,13 +597,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: if self.getNumPartitions() > 1: @@ -550,12 +618,12 @@ def sortPartition(iterator): return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect() + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary - bounds = [samples[len(samples) * (i + 1) / numPartitions] + bounds = [samples[int(len(samples) * (i + 1) / numPartitions)] for i in range(0, numPartitions - 1)] def rangePartitioner(k): @@ -618,12 +686,13 @@ def groupBy(self, f, numPartitions=None): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + @ignore_unicode_prefix def pipe(self, command, env={}): """ Return an RDD created by piping elements to a forked external process. >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect() - ['1', '2', '', '3'] + [u'1', u'2', u'', u'3'] """ def func(iterator): pipe = Popen( @@ -631,17 +700,18 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') + s = str(obj).rstrip('\n') + '\n' + out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in iter(pipe.stdout.readline, '')) + return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b'')) return self.mapPartitions(func) def foreach(self, f): """ Applies a function to all elements of this RDD. - >>> def f(x): print x + >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ def processPartition(iterator): @@ -656,7 +726,7 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: - ... print x + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): @@ -672,21 +742,8 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(bytesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. - tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) - tempFile.close() - self.ctx._writeToFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in self._jrdd_deserializer.load_stream(tempFile): - yield item - os.unlink(tempFile.name) + port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(port, self._jrdd_deserializer)) def reduce(self, f): """ @@ -716,6 +773,43 @@ def func(iterator): return reduce(f, vals) raise ValueError("Can not reduce() empty RDD") + def treeReduce(self, f, depth=2): + """ + Reduces the elements of this RDD in a multi-level tree pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeReduce(add) + -5 + >>> rdd.treeReduce(add, 1) + -5 + >>> rdd.treeReduce(add, 2) + -5 + >>> rdd.treeReduce(add, 5) + -5 + >>> rdd.treeReduce(add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + zeroValue = None, True # Use the second entry to indicate whether this is a dummy value. + + def op(x, y): + if x[1]: + return y + elif y[1]: + return x + else: + return f(x[0], y[0]), False + + reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth) + if reduced[1]: + raise ValueError("Cannot reduce empty RDD.") + return reduced[0] + def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all @@ -767,6 +861,58 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) + def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): + """ + Aggregates the elements of this RDD in a multi-level tree + pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeAggregate(0, add, add) + -5 + >>> rdd.treeAggregate(0, add, add, 1) + -5 + >>> rdd.treeAggregate(0, add, add, 2) + -5 + >>> rdd.treeAggregate(0, add, add, 5) + -5 + >>> rdd.treeAggregate(0, add, add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + if self.getNumPartitions() == 0: + return zeroValue + + def aggregatePartition(iterator): + acc = zeroValue + for obj in iterator: + acc = seqOp(acc, obj) + yield acc + + partiallyAggregated = self.mapPartitions(aggregatePartition) + numPartitions = partiallyAggregated.getNumPartitions() + scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2) + # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree + # aggregation. + while numPartitions > scale + numPartitions / scale: + numPartitions /= scale + curNumPartitions = int(numPartitions) + + def mapPartition(i, iterator): + for obj in iterator: + yield (i % curNumPartitions, obj) + + partiallyAggregated = partiallyAggregated \ + .mapPartitionsWithIndex(mapPartition) \ + .reduceByKey(combOp, curNumPartitions) \ + .values() + + return partiallyAggregated.reduce(combOp) + def max(self, key=None): """ Find the maximum item in this RDD. @@ -864,7 +1010,7 @@ def histogram(self, buckets): (('a', 'b', 'c'), [2, 2]) """ - if isinstance(buckets, (int, long)): + if isinstance(buckets, int): if buckets < 1: raise ValueError("number of buckets must be >= 1") @@ -900,6 +1046,7 @@ def minmax(a, b): raise ValueError("Can not generate buckets with infinite value") # keep them as integer if possible + inc = int(inc) if inc * buckets != maxv - minv: inc = (maxv - minv) * 1.0 / buckets @@ -1017,7 +1164,7 @@ def countPartition(iterator): yield counts def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) @@ -1077,7 +1224,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1140,7 +1287,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1258,8 +1405,8 @@ def saveAsPickleFile(self, path, batchSize=10): >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3) - >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) - [1, 2, 'rdd', 'spark'] + >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect()) + ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: ser = AutoBatchedSerializer(PickleSerializer()) @@ -1267,10 +1414,15 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) - def saveAsTextFile(self, path): + @ignore_unicode_prefix + def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. + @param path: path to text file + @param compressionCodecClass: (None by default) string i.e. + "org.apache.hadoop.io.compress.GzipCodec" + >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) @@ -1286,17 +1438,32 @@ def saveAsTextFile(self, path): >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name) >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*")))) '\\n\\n\\nbar\\nfoo\\n' + + Using compressionCodecClass + + >>> tempFile3 = NamedTemporaryFile(delete=True) + >>> tempFile3.close() + >>> codec = "org.apache.hadoop.io.compress.GzipCodec" + >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) + >>> from fileinput import input, hook_compressed + >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)) + >>> b''.join(result).decode('utf-8') + u'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: - if not isinstance(x, basestring): + if not isinstance(x, (unicode, bytes)): x = unicode(x) if isinstance(x, unicode): x = x.encode("utf-8") yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) + if compressionCodecClass: + compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec) + else: + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -1320,7 +1487,7 @@ def keys(self): >>> m.collect() [1, 3] """ - return self.map(lambda (k, v): k) + return self.map(lambda x: x[0]) def values(self): """ @@ -1330,7 +1497,7 @@ def values(self): >>> m.collect() [2, 4] """ - return self.map(lambda (k, v): v) + return self.map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ @@ -1369,7 +1536,7 @@ def reducePartition(iterator): yield m def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1466,11 +1633,14 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) + >>> len(set(sets[0]).intersection(set(sets[1]))) + 0 """ if numPartitions is None: numPartitions = self._defaultReducePartitions() + partitioner = Partitioner(numPartitions, partitionFunc) + if self.partitioner == partitioner: + return self # Transferring O(n) objects to Java is too expensive. # Instead, we'll form the hash buckets in Python, @@ -1496,37 +1666,35 @@ def add_shuffle_key(split, iterator): if (c % 1000 == 0 and get_used_memory() > limit or c > batch): n, size = len(buckets), 0 - for split in buckets.keys(): + for split in list(buckets.keys()): yield pack_long(split) d = outputSerializer.dumps(buckets[split]) del buckets[split] yield d size += len(d) - avg = (size / n) >> 20 + avg = int(size / n) >> 20 # let 1M < avg < 10M if avg < 1: batch *= 1.5 elif avg > 10: - batch = max(batch / 1.5, 1) + batch = max(int(batch / 1.5), 1) c = 0 - for split, items in buckets.iteritems(): + for split, items in buckets.items(): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = self.mapPartitionsWithIndex(add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True) keyed._bypass_serializer = True with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) - jrdd = pairRDD.partitionBy(partitioner).values() + jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner)) rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, - # even if partitionFunc is a lambda: - rdd._partitionFunc = partitionFunc + rdd.partitioner = partitioner return rdd # TODO: add control over map-side aggregation @@ -1560,28 +1728,26 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() - locally_combined = self.mapPartitions(combineLocally) + locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) - return merger.iteritems() + return merger.items() - return shuffled.mapPartitions(_mergeCombiners, True) + return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ @@ -1608,7 +1774,7 @@ def foldByKey(self, zeroValue, func, numPartitions=None): >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add - >>> rdd.foldByKey(0, add).collect() + >>> sorted(rdd.foldByKey(0, add).collect()) [('a', 2), ('b', 1)] """ def createZero(): @@ -1616,21 +1782,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. + sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(rdd.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1642,8 +1815,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.items() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.items() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ @@ -1656,7 +1848,7 @@ def flatMapValues(self, f): >>> x.flatMapValues(f).collect() [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')] """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def mapValues(self, f): @@ -1670,7 +1862,7 @@ def mapValues(self, f): >>> x.mapValues(f).collect() [('a', 3), ('b', 1)] """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def groupWith(self, other, *others): @@ -1681,8 +1873,7 @@ def groupWith(self, other, *others): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) >>> z = sc.parallelize([("b", 42)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ - sorted(list(w.groupWith(x, y, z).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))] [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] """ @@ -1697,7 +1888,7 @@ def cogroup(self, other, numPartitions=None): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))] [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup((self, other), numPartitions) @@ -1733,8 +1924,9 @@ def subtractByKey(self, other, numPartitions=None): >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] """ - def filter_func((key, vals)): - return vals[0] and not vals[1] + def filter_func(pair): + key, (val1, val2) = pair + return val1 and not val2 return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) def subtract(self, other, numPartitions=None): @@ -1756,8 +1948,8 @@ def keyBy(self, f): >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) >>> y = sc.parallelize(zip(range(0,5), range(0,5))) - >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect())) - [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + >>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())] + [(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])] """ return self.map(lambda x: (f(x), x)) @@ -1816,7 +2008,7 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: + if my_batch != other_batch or not my_batch: # use the smallest batchSize for both of them batchSize = min(my_batch, other_batch) if batchSize <= 0: @@ -1886,17 +2078,18 @@ def name(self): """ Return the name of this RDD. """ - name_ = self._jrdd.name() - if name_: - return name_.encode('utf-8') + n = self._jrdd.name() + if n: + return n + @ignore_unicode_prefix def setName(self, name): """ Assign a name to this RDD. - >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1 = sc.parallelize([1, 2]) >>> rdd1.setName('RDD1').name() - 'RDD1' + u'RDD1' """ self._jrdd.setName(name) return self @@ -1958,10 +2151,10 @@ def lookup(self, key): >>> sorted.lookup(1024) [] """ - values = self.filter(lambda (k, v): k == key).values() + values = self.filter(lambda kv: kv[0] == key).values() - if self._partitionFunc is not None: - return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False) + if self.partitioner is not None: + return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) return values.collect() @@ -1977,6 +2170,7 @@ def _to_java_object_rdd(self): def countApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate version of count() that returns a potentially incomplete result within a timeout, even if not all tasks have finished. @@ -1990,11 +2184,12 @@ def countApprox(self, timeout, confidence=0.95): def sumApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the sum within a timeout or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) + >>> r = sum(range(1000)) >>> (rdd.sumApprox(1000) - r) / r < 0.05 True """ @@ -2006,11 +2201,12 @@ def sumApprox(self, timeout, confidence=0.95): def meanApprox(self, timeout, confidence=0.95): """ .. note:: Experimental + Approximate operation to return the mean within a timeout or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) / 1000.0 + >>> r = sum(range(1000)) / 1000.0 >>> (rdd.meanApprox(1000) - r) / r < 0.05 True """ @@ -2022,6 +2218,7 @@ def meanApprox(self, timeout, confidence=0.95): def countApproxDistinct(self, relativeSD=0.05): """ .. note:: Experimental + Return approximate number of distinct elements in the RDD. The algorithm used is based on streamlib's implementation of @@ -2034,10 +2231,10 @@ def countApproxDistinct(self, relativeSD=0.05): It must be greater than 0.000017. >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() - >>> 950 < n < 1050 + >>> 900 < n < 1100 True >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() - >>> 18 < n < 22 + >>> 16 < n < 24 True """ if relativeSD < 0.000017: @@ -2048,6 +2245,39 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) + def toLocalIterator(self): + """ + Return an iterator that contains all of the elements in this RDD. + The iterator will consume as much memory as the largest partition in this RDD. + >>> rdd = sc.parallelize(range(10)) + >>> [x for x in rdd.toLocalIterator()] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + for partition in range(self.getNumPartitions()): + rows = self.context.runJob(self, lambda x: x, [partition]) + for row in rows: + yield row + + +def _prepare_for_python_RDD(sc, command, obj=None): + # the serialized command will be compressed by broadcast + ser = CloudPickleSerializer() + pickled_command = ser.dumps((command, sys.version_info[:2])) + if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + # There is a bug in py4j.java_gateway.JavaClass with auto_convert + # https://github.com/bartdag/py4j/issues/161 + # TODO: use auto_convert once py4j fix the bug + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in sc._pickled_broadcast_vars], + sc._gateway._gateway_client) + sc._pickled_broadcast_vars.clear() + env = MapConverter().convert(sc.environment, sc._gateway._gateway_client) + includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client) + return pickled_command, broadcast_vars, env, includes + class PipelinedRDD(RDD): @@ -2093,13 +2323,10 @@ def pipeline_func(split, iterator): self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False - self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None - self._broadcast = None + self.partitioner = prev.partitioner if self.preservesPartitioning else None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): @@ -2107,34 +2334,25 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) - # the serialized command will be compressed by broadcast - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - self._broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(self._broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx._gateway._gateway_client) - self.ctx._pickled_broadcast_vars.clear() - env = MapConverter().convert(self.ctx.environment, - self.ctx._gateway._gateway_client) - includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), + bytearray(pickled_cmd), env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator) + bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 459e1427803c..fe8f87324804 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -23,7 +23,7 @@ class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - self._seed = seed if seed is not None else random.randint(0, sys.maxint) + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) self._withReplacement = withReplacement self._random = None @@ -31,7 +31,7 @@ def initRandomGenerator(self, split): self._random = random.Random(self._seed ^ split) # mixing because the initial seeds are close to each other - for _ in xrange(10): + for _ in range(10): self._random.randint(0, 1) def getUniformSample(self): diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6..1ab5ce14c353 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b8bda835174b..d8cdcda3a378 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -49,16 +49,24 @@ >>> sc.stop() """ -import cPickle -from itertools import chain, izip, product +import sys +from itertools import chain, product import marshal import struct -import sys import types import collections import zlib import itertools +if sys.version < '3': + import cPickle as pickle + protocol = 2 + from itertools import izip as zip +else: + import pickle + protocol = 3 + xrange = range + from pyspark import cloudpickle @@ -70,6 +78,7 @@ class SpecialLengths(object): PYTHON_EXCEPTION_THROWN = -2 TIMING_DATA = -3 END_OF_STREAM = -4 + NULL = -5 class Serializer(object): @@ -96,7 +105,7 @@ def _load_stream_without_unbatching(self, stream): # subclasses should override __eq__ as appropriate. def __eq__(self, other): - return isinstance(other, self.__class__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -133,6 +142,8 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if serialized is None: + raise ValueError("serialized value should not be None") if len(serialized) > (1 << 31): raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) @@ -145,8 +156,10 @@ def _read_with_length(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + elif length == SpecialLengths.NULL: + return None obj = stream.read(length) - if obj == "": + if len(obj) < length: raise EOFError return self.loads(obj) @@ -207,14 +220,33 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): return self.serializer.load_stream(stream) - def __eq__(self, other): - return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer and other.batchSize == self.batchSize) - def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. + """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in range(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize) + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -242,12 +274,8 @@ def dump_stream(self, iterator, stream): elif size > best * 10 and batch > 1: batch /= 2 - def __eq__(self, other): - return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer and other.bestSize == self.bestSize) - - def __str__(self): - return "AutoBatchedSerializer(%s)" % str(self.serializer) + def __repr__(self): + return "AutoBatchedSerializer(%s)" % self.serializer class CartesianDeserializer(FramedSerializer): @@ -257,6 +285,7 @@ class CartesianDeserializer(FramedSerializer): """ def __init__(self, key_ser, val_ser): + FramedSerializer.__init__(self) self.key_ser = key_ser self.val_ser = val_ser @@ -265,7 +294,7 @@ def prepare_keys_values(self, stream): val_stream = self.val_ser._load_stream_without_unbatching(stream) key_is_batched = isinstance(self.key_ser, BatchedSerializer) val_is_batched = isinstance(self.val_ser, BatchedSerializer) - for (keys, vals) in izip(key_stream, val_stream): + for (keys, vals) in zip(key_stream, val_stream): keys = keys if key_is_batched else [keys] vals = vals if val_is_batched else [vals] yield (keys, vals) @@ -275,10 +304,6 @@ def load_stream(self, stream): for pair in product(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, CartesianDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -290,22 +315,14 @@ class PairDeserializer(CartesianDeserializer): Deserializes the JavaRDD zip() of two PythonRDDs. """ - def __init__(self, key_ser, val_ser): - self.key_ser = key_ser - self.val_ser = val_ser - def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): if len(keys) != len(vals): raise ValueError("Can not deserialize RDD with different number of items" " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in izip(keys, vals): + for pair in zip(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, PairDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) @@ -354,8 +371,8 @@ def _hijack_namedtuple(): global _old_namedtuple # or it will put in closure def _copy_func(f): - return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + return types.FunctionType(f.__code__, f.__globals__, f.__name__, + f.__defaults__, f.__closure__) _old_namedtuple = _copy_func(collections.namedtuple) @@ -364,15 +381,15 @@ def namedtuple(*args, **kwargs): return _hack_namedtuple(cls) # replace namedtuple with new one - collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.iteritems(): + for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple and hasattr(o, "_fields") and "__reduce__" not in o.__dict__): @@ -385,7 +402,7 @@ def namedtuple(*args, **kwargs): class PickleSerializer(FramedSerializer): """ - Serializes objects using Python's cPickle serializer: + Serializes objects using Python's pickle serializer: http://docs.python.org/2/library/pickle.html @@ -394,10 +411,14 @@ class PickleSerializer(FramedSerializer): """ def dumps(self, obj): - return cPickle.dumps(obj, 2) + return pickle.dumps(obj, protocol) - def loads(self, obj): - return cPickle.loads(obj) + if sys.version >= '3': + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) + else: + def loads(self, obj, encoding=None): + return pickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -426,7 +447,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol automatically + Choose marshal or pickle as serialization protocol automatically """ def __init__(self): @@ -435,19 +456,19 @@ def __init__(self): def dumps(self, obj): if self._type is not None: - return 'P' + cPickle.dumps(obj, -1) + return b'P' + pickle.dumps(obj, -1) try: - return 'M' + marshal.dumps(obj) + return b'M' + marshal.dumps(obj) except Exception: - self._type = 'P' - return 'P' + cPickle.dumps(obj, -1) + self._type = b'P' + return b'P' + pickle.dumps(obj, -1) def loads(self, obj): _type = obj[0] - if _type == 'M': + if _type == b'M': return marshal.loads(obj[1:]) - elif _type == 'P': - return cPickle.loads(obj[1:]) + elif _type == b'P': + return pickle.loads(obj[1:]) else: raise ValueError("invalid sevialization type: %s" % _type) @@ -467,8 +488,8 @@ def dumps(self, obj): def loads(self, obj): return self.serializer.loads(zlib.decompress(obj)) - def __eq__(self, other): - return isinstance(other, CompressedSerializer) and self.serializer == other.serializer + def __repr__(self): + return "CompressedSerializer(%s)" % self.serializer class UTF8Deserializer(Serializer): @@ -477,13 +498,15 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ - def __init__(self, use_unicode=False): + def __init__(self, use_unicode=True): self.use_unicode = use_unicode def loads(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + elif length == SpecialLengths.NULL: + return None s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s @@ -496,13 +519,13 @@ def load_stream(self, stream): except EOFError: return - def __eq__(self, other): - return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode + def __repr__(self): + return "UTF8Deserializer(%s)" % self.use_unicode def read_long(stream): length = stream.read(8) - if length == "": + if not length: raise EOFError return struct.unpack("!q", length)[0] @@ -517,7 +540,7 @@ def pack_long(value): def read_int(stream): length = stream.read(4) - if length == "": + if not length: raise EOFError return struct.unpack("!i", length)[0] diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 89cf76920e35..144cdf0b0cdd 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -21,23 +21,21 @@ This file is designed to be launched as a PYTHONSTARTUP script. """ -import sys -if sys.version_info[0] != 2: - print("Error: Default Python used is Python%s" % sys.version_info.major) - print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.") - sys.exit(1) - - import atexit import os import platform + +import py4j + import pyspark from pyspark.context import SparkContext +from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the equivalent of ADD_JARS -add_files = (os.environ.get("ADD_FILES").split(',') - if os.environ.get("ADD_FILES") is not None else None) +# this is the deprecated equivalent of ADD_JARS +add_files = None +if os.environ.get("ADD_FILES") is not None: + add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -45,6 +43,18 @@ sc = SparkContext(appName="PySparkShell", pyFiles=add_files) atexit.register(lambda: sc.stop()) +try: + # Try to access HiveConf, it will raise exception if Hive is not added + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + sqlContext = HiveContext(sc) +except py4j.protocol.Py4JError: + sqlContext = SQLContext(sc) +except TypeError: + sqlContext = SQLContext(sc) + +# for compatibility +sqlCtx = sqlContext + print("""Welcome to ____ __ / __/__ ___ _____/ /__ @@ -56,9 +66,10 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc.") +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: + print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd50200..1d0b16cade8b 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -69,8 +78,8 @@ def _get_local_dirs(sub): # global stats -MemoryBytesSpilled = 0L -DiskBytesSpilled = 0L +MemoryBytesSpilled = 0 +DiskBytesSpilled = 0 class Aggregator(object): @@ -117,7 +126,7 @@ def mergeCombiners(self, iterator): """ Merge the combined items by mergeCombiner """ raise NotImplementedError - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ raise NotImplementedError @@ -147,9 +156,15 @@ def mergeCombiners(self, iterator): for k, v in iterator: d[k] = comb(d[k], v) if k in d else v - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ - return self.data.iteritems() + return iter(self.data.items()) + + +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) class ExternalMerger(Merger): @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) - 499950000 + >>> sum(v for k,v in merger.items()) + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) - 499950000 + >>> sum(v for k,v in merger.items()) + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -330,12 +335,12 @@ def _spill(self): # above limit at the first time. # open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,14 +349,14 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(iter(self.pdata[i].items()), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) @@ -359,10 +364,10 @@ def _spill(self): gc.collect() # release the memory as much as possible MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 - def iteritems(self): + def items(self): """ Return all merged items as iterator """ if not self.pdata and not self.spills: - return self.data.iteritems() + return iter(self.data.items()) return self._external_items() def _external_items(self): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,58 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - shutil.rmtree(d, True) - - def _recursive_merged_items(self, start): + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) + + return self.data.items() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + with open(p, 'rb') as f: + m.mergeCombiners(self.serializer.load_stream(f), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,9 +450,10 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. + >>> sorter = ExternalSorter(1) # 1M >>> import random - >>> l = range(1024) + >>> l = list(range(1024)) >>> random.shuffle(l) >>> sorted(l) == list(sorter.sorted(l)) True @@ -469,7 +463,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -492,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self._next_limit() + batch, limit = 100, self.memory_limit chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -503,21 +497,30 @@ def sorted(self, iterator, key=None, reverse=False): break used_memory = get_used_memory() - if used_memory > self.memory_limit: + if used_memory > limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) - with open(path, 'w') as f: + with open(path, 'wb') as f: self.serializer.dump_stream(current_chunk, f) - chunks.append(self.serializer.load_stream(open(path))) + + def load(f): + for v in self.serializer.load_stream(f): + yield v + # close the file explicit once we consume all the items + # to avoid ResourceWarning in Python3 + f.close() + chunks.append(load(open(path, 'rb'))) current_chunk = [] gc.collect() + batch //= 2 limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: - batch = min(batch * 2, 10000) + batch = min(int(batch * 1.5), 10000) current_chunk.sort(key=key, reverse=reverse) if not chunks: @@ -529,6 +532,313 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(list(range(100))) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + with os.fdopen(os.dup(self._file.fileno()), "rb") as f: + f.seek(0) + serialized = f.read() + else: + serialized = b'' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'rb') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id(self))) + self._file = open(p, "wb+", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def __del__(self): + if self._file: + self._file.close() + self._file = None + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. + + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i // 3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(zip(k, v)) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iterator + + def __iter__(self): + key, values = None, None + for k, v in self.iterator: + if values is not None and k == key: + values.append(v) + else: + if values is not None: + yield (key, values) + key = k + values = ExternalListOfList([v]) + if values is not None: + yield (key, values) + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'wb') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.items(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "wb") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].items(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if size >= self.memory_limit << 17: # * 1M / 8 + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + return self.data.items() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + with open(p, 'rb', 65536) as f: + for v in self.serializer.load_stream(f): + yield v + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py deleted file mode 100644 index 1990323249cf..000000000000 --- a/python/pyspark/sql.py +++ /dev/null @@ -1,2138 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -public classes of Spark SQL: - - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. -""" - -import itertools -import decimal -import datetime -import keyword -import warnings -import json -import re -from array import array -from operator import itemgetter -from itertools import imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - CloudPickleSerializer, UTF8Deserializer -from pyspark.storagelevel import StorageLevel -from pyspark.traceback_utils import SCCallSiteSync - - -__all__ = [ - "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", - "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", - "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] - - -class DataType(object): - - """Spark SQL DataType""" - - def __repr__(self): - return self.__class__.__name__ - - def __hash__(self): - return hash(str(self)) - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - @classmethod - def typeName(cls): - return cls.__name__[:-4].lower() - - def jsonValue(self): - return self.typeName() - - def json(self): - return json.dumps(self.jsonValue(), - separators=(',', ':'), - sort_keys=True) - - -class PrimitiveTypeSingleton(type): - - """Metaclass for PrimitiveType""" - - _instances = {} - - def __call__(cls): - if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() - return cls._instances[cls] - - -class PrimitiveType(DataType): - - """Spark SQL PrimitiveType""" - - __metaclass__ = PrimitiveTypeSingleton - - def __eq__(self, other): - # because they should be the same object - return self is other - - -class NullType(PrimitiveType): - - """Spark SQL NullType - - The data type representing None, used for the types which has not - been inferred. - """ - - -class StringType(PrimitiveType): - - """Spark SQL StringType - - The data type representing string values. - """ - - -class BinaryType(PrimitiveType): - - """Spark SQL BinaryType - - The data type representing bytearray values. - """ - - -class BooleanType(PrimitiveType): - - """Spark SQL BooleanType - - The data type representing bool values. - """ - - -class DateType(PrimitiveType): - - """Spark SQL DateType - - The data type representing datetime.date values. - """ - - -class TimestampType(PrimitiveType): - - """Spark SQL TimestampType - - The data type representing datetime.datetime values. - """ - - -class DecimalType(DataType): - - """Spark SQL DecimalType - - The data type representing decimal.Decimal values. - """ - - def __init__(self, precision=None, scale=None): - self.precision = precision - self.scale = scale - self.hasPrecisionInfo = precision is not None - - def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" - - def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" - - -class DoubleType(PrimitiveType): - - """Spark SQL DoubleType - - The data type representing float values. - """ - - -class FloatType(PrimitiveType): - - """Spark SQL FloatType - - The data type representing single precision floating-point values. - """ - - -class ByteType(PrimitiveType): - - """Spark SQL ByteType - - The data type representing int values with 1 singed byte. - """ - - -class IntegerType(PrimitiveType): - - """Spark SQL IntegerType - - The data type representing int values. - """ - - -class LongType(PrimitiveType): - - """Spark SQL LongType - - The data type representing long values. If the any value is - beyond the range of [-9223372036854775808, 9223372036854775807], - please use DecimalType. - """ - - -class ShortType(PrimitiveType): - - """Spark SQL ShortType - - The data type representing int values with 2 signed bytes. - """ - - -class ArrayType(DataType): - - """Spark SQL ArrayType - - The data type representing list values. An ArrayType object - comprises two fields, elementType (a DataType) and containsNull (a bool). - The field of elementType is used to specify the type of array elements. - The field of containsNull is used to specify if the array has None values. - - """ - - def __init__(self, elementType, containsNull=True): - """Creates an ArrayType - - :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains None values. - - >>> ArrayType(StringType) == ArrayType(StringType, True) - True - >>> ArrayType(StringType, False) == ArrayType(StringType) - False - """ - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self): - return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull} - - @classmethod - def fromJson(cls, json): - return ArrayType(_parse_datatype_json_value(json["elementType"]), - json["containsNull"]) - - -class MapType(DataType): - - """Spark SQL MapType - - The data type representing dict values. A MapType object comprises - three fields, keyType (a DataType), valueType (a DataType) and - valueContainsNull (a bool). - - The field of keyType is used to specify the type of keys in the map. - The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this - map has None values. - - For values of a MapType column, keys are not allowed to have None values. - - """ - - def __init__(self, keyType, valueType, valueContainsNull=True): - """Creates a MapType - :param keyType: the data type of keys. - :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains - null values. - - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) - True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) - False - """ - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self): - return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) - - def jsonValue(self): - return {"type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull} - - @classmethod - def fromJson(cls, json): - return MapType(_parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), - json["valueContainsNull"]) - - -class StructField(DataType): - - """Spark SQL StructField - - Represents a field in a StructType. - A StructField object comprises three fields, name (a string), - dataType (a DataType) and nullable (a bool). The field of name - is the name of a StructField. The field of dataType specifies - the data type of a StructField. - - The field of nullable specifies if values of a StructField can - contain None values. - - """ - - def __init__(self, name, dataType, nullable=True, metadata=None): - """Creates a StructField - :param name: the name of this field. - :param dataType: the data type of this field. - :param nullable: indicates whether values of this field - can be null. - :param metadata: metadata of this field, which is a map from string - to simple type that can be serialized to JSON - automatically - - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) - True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) - False - """ - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self): - return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) - - def jsonValue(self): - return {"name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata} - - @classmethod - def fromJson(cls, json): - return StructField(json["name"], - _parse_datatype_json_value(json["type"]), - json["nullable"], - json["metadata"]) - - -class StructType(DataType): - - """Spark SQL StructType - - The data type representing rows. - A StructType object comprises a list of L{StructField}. - - """ - - def __init__(self, fields): - """Creates a StructType - - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) - >>> struct1 == struct2 - True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) - >>> struct1 == struct2 - False - """ - self.fields = fields - - def __repr__(self): - return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) - - def jsonValue(self): - return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} - - @classmethod - def fromJson(cls, json): - return StructType([StructField.fromJson(f) for f in json["fields"]]) - - -class UserDefinedType(DataType): - """ - .. note:: WARN: Spark Internal Use Only - SQL User-Defined Type (UDT). - """ - - @classmethod - def typeName(cls): - return cls.__name__.lower() - - @classmethod - def sqlType(cls): - """ - Underlying SQL storage type for this UDT. - """ - raise NotImplementedError("UDT must implement sqlType().") - - @classmethod - def module(cls): - """ - The Python module of the UDT. - """ - raise NotImplementedError("UDT must implement module().") - - @classmethod - def scalaUDT(cls): - """ - The class name of the paired Scala UDT. - """ - raise NotImplementedError("UDT must have a paired Scala UDT.") - - def serialize(self, obj): - """ - Converts the a user-type object into a SQL datum. - """ - raise NotImplementedError("UDT must implement serialize().") - - def deserialize(self, datum): - """ - Converts a SQL datum into a user-type object. - """ - raise NotImplementedError("UDT must implement deserialize().") - - def json(self): - return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) - - def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } - return schema - - @classmethod - def fromJson(cls, json): - pyUDT = json["pyClass"] - split = pyUDT.rfind(".") - pyModule = pyUDT[:split] - pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) - UDT = getattr(m, pyClass) - return UDT() - - def __eq__(self, other): - return type(self) == type(other) - - -_all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - - -_all_complex_types = dict((v.typeName(), v) - for v in [ArrayType, MapType, StructType]) - - -def _parse_datatype_json_string(json_string): - """Parses the given data type JSON string. - >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) - ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True - >>> # Simple ArrayType. - >>> simple_arraytype = ArrayType(StringType(), True) - >>> check_datatype(simple_arraytype) - True - >>> # Simple MapType. - >>> simple_maptype = MapType(StringType(), LongType()) - >>> check_datatype(simple_maptype) - True - >>> # Simple StructType. - >>> simple_structtype = StructType([ - ... StructField("a", DecimalType(), False), - ... StructField("b", BooleanType(), True), - ... StructField("c", LongType(), True), - ... StructField("d", BinaryType(), False)]) - >>> check_datatype(simple_structtype) - True - >>> # Complex StructType. - >>> complex_structtype = StructType([ - ... StructField("simpleArray", simple_arraytype, True), - ... StructField("simpleMap", simple_maptype, True), - ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False), - ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) - >>> check_datatype(complex_structtype) - True - >>> # Complex ArrayType. - >>> complex_arraytype = ArrayType(complex_structtype, True) - >>> check_datatype(complex_arraytype) - True - >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, - ... complex_arraytype, False) - >>> check_datatype(complex_maptype) - True - >>> check_datatype(ExamplePointUDT()) - True - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) - True - """ - return _parse_datatype_json_value(json.loads(json_string)) - - -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - -def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() - elif json_value == u'decimal': - return DecimalType() - elif _FIXED_DECIMAL.match(json_value): - m = _FIXED_DECIMAL.match(json_value) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % json_value) - else: - tpe = json_value["type"] - if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) - elif tpe == 'udt': - return UserDefinedType.fromJson(json_value) - else: - raise ValueError("not supported type: %s" % tpe) - - -# Mapping Python types to Spark SQL DataType -_type_mappings = { - type(None): NullType, - bool: BooleanType, - int: IntegerType, - long: LongType, - float: DoubleType, - str: StringType, - unicode: StringType, - bytearray: BinaryType, - decimal.Decimal: DecimalType, - datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, -} - - -def _infer_type(obj): - """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT - """ - if obj is None: - raise ValueError("Can not infer type for None") - - if hasattr(obj, '__UDT__'): - return obj.__UDT__ - - dataType = _type_mappings.get(type(obj)) - if dataType is not None: - return dataType() - - if isinstance(obj, dict): - for key, value in obj.iteritems(): - if key is not None and value is not None: - return MapType(_infer_type(key), _infer_type(value), True) - else: - return MapType(NullType(), NullType(), True) - elif isinstance(obj, (list, array)): - for v in obj: - if v is not None: - return ArrayType(_infer_type(obj[0]), True) - else: - return ArrayType(NullType(), True) - else: - try: - return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) - - -def _infer_schema(row): - """Infer the schema from dict/namedtuple/object""" - if isinstance(row, dict): - items = sorted(row.items()) - - elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in row): - items = row - else: - raise ValueError("Can't infer schema from tuple") - - elif hasattr(row, "__dict__"): # object - items = sorted(row.__dict__.items()) - - else: - raise ValueError("Can not infer schema for type: %s" % type(row)) - - fields = [StructField(k, _infer_type(v), True) for k, v in items] - return StructType(fields) - - -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - False - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - else: - return False - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = map(_python_to_sql_converter, types) - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) - else: - raise ValueError("Unexpected type %r" % dataType) - - -def _has_nulltype(dt): - """ Return whether there is NullType in `dt` or not """ - if isinstance(dt, StructType): - return any(_has_nulltype(f.dataType) for f in dt.fields) - elif isinstance(dt, ArrayType): - return _has_nulltype((dt.elementType)) - elif isinstance(dt, MapType): - return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) - else: - return isinstance(dt, NullType) - - -def _merge_type(a, b): - if isinstance(a, NullType): - return b - elif isinstance(b, NullType): - return a - elif type(a) is not type(b): - # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) - - # same type - if isinstance(a, StructType): - nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) - for f in a.fields] - names = set([f.name for f in fields]) - for n in nfs: - if n not in names: - fields.append(StructField(n, nfs[n])) - return StructType(fields) - - elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) - - elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), - True) - else: - return a - - -def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ - if isinstance(dataType, ArrayType): - conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) - - elif isinstance(dataType, MapType): - kconv = _create_converter(dataType.keyType) - vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) - - elif isinstance(dataType, NullType): - return lambda x: None - - elif not isinstance(dataType, StructType): - return lambda x: x - - # dataType must be StructType - names = [f.name for f in dataType.fields] - converters = [_create_converter(f.dataType) for f in dataType.fields] - - def convert_struct(obj): - if obj is None: - return - - if isinstance(obj, tuple): - if hasattr(obj, "_fields"): - d = dict(zip(obj._fields, obj)) - elif hasattr(obj, "__FIELDS__"): - d = dict(zip(obj.__FIELDS__, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - d = dict(obj) - else: - raise ValueError("unexpected tuple: %s" % str(obj)) - - elif isinstance(obj, dict): - d = obj - elif hasattr(obj, "__dict__"): # object - d = obj.__dict__ - else: - raise ValueError("Unexpected obj: %s" % obj) - - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - - return convert_struct - - -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _split_schema_abstract(s): - """ - split the schema abstract into fields - - >>> _split_schema_abstract("a b c") - ['a', 'b', 'c'] - >>> _split_schema_abstract("a(a b)") - ['a(a b)'] - >>> _split_schema_abstract("a b[] c{a b}") - ['a', 'b[]', 'c{a b}'] - >>> _split_schema_abstract(" ") - [] - """ - - r = [] - w = '' - brackets = [] - for c in s: - if c == ' ' and not brackets: - if w: - r.append(w) - w = '' - else: - w += c - if c in _BRACKETS: - brackets.append(c) - elif c in _BRACKETS.values(): - if not brackets or c != _BRACKETS[brackets.pop()]: - raise ValueError("unexpected " + c) - - if brackets: - raise ValueError("brackets not closed: %s" % brackets) - if w: - r.append(w) - return r - - -def _parse_field_abstract(s): - """ - Parse a field in schema abstract - - >>> _parse_field_abstract("a") - StructField(a,None,true) - >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... - >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) - >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) - """ - if set(_BRACKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRACKETS if c in s)) - name = s[:idx] - return StructField(name, _parse_schema_abstract(s[idx:]), True) - else: - return StructField(s, None, True) - - -def _parse_schema_abstract(s): - """ - parse abstract into schema - - >>> _parse_schema_abstract("a b c") - StructType...a...b...c... - >>> _parse_schema_abstract("a[b c] b{}") - StructType...a,ArrayType...b...c...b,MapType... - >>> _parse_schema_abstract("c{} d{a b}") - StructType...c,MapType...d,MapType...a...b... - >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) - """ - s = s.strip() - if not s: - return - - elif s.startswith('('): - return _parse_schema_abstract(s[1:-1]) - - elif s.startswith('['): - return ArrayType(_parse_schema_abstract(s[1:-1]), True) - - elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) - - parts = _split_schema_abstract(s) - fields = [_parse_field_abstract(p) for p in parts] - return StructType(fields) - - -def _infer_schema_type(obj, dataType): - """ - Fill the dataType with types infered from obj - - >>> schema = _parse_schema_abstract("a b c d") - >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) - >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType...DateType... - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... - """ - if dataType is None: - return _infer_type(obj) - - if not obj: - return NullType() - - if isinstance(dataType, ArrayType): - eType = _infer_schema_type(obj[0], dataType.elementType) - return ArrayType(eType, True) - - elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() - return MapType(_infer_schema_type(k, dataType.keyType), - _infer_schema_type(v, dataType.valueType)) - - elif isinstance(dataType, StructType): - fs = dataType.fields - assert len(fs) == len(obj), \ - "Obj(%s) have different length with fields(%s)" % (obj, fs) - fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] - return StructType(fields) - - else: - raise ValueError("Unexpected dataType: %s" % dataType) - - -_acceptable_types = { - BooleanType: (bool,), - ByteType: (int, long), - ShortType: (int, long), - IntegerType: (int, long), - LongType: (int, long), - FloatType: (float,), - DoubleType: (float,), - DecimalType: (decimal.Decimal,), - StringType: (str, unicode), - BinaryType: (bytearray,), - DateType: (datetime.date,), - TimestampType: (datetime.datetime,), - ArrayType: (list, tuple, array), - MapType: (dict,), - StructType: (tuple, list), -} - - -def _verify_type(obj, dataType): - """ - Verify the type of obj against dataType, raise an exception if - they do not match. - - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) - >>> _verify_type(range(3), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - """ - # all objects are nullable - if obj is None: - return - - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) - return - - _type = type(dataType) - assert _type in _acceptable_types, "unkown datatype: %s" % dataType - - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) - - if isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType) - - elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) - - elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) - - -_cached_cls = {} - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for primitive types - return lambda x: x - - class Row(tuple): - - """ Row in SchemaRDD """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) - - def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) - - return Row - - -class SQLContext(object): - - """Main entry point for Spark SQL functionality. - - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as - tables, execute SQL over tables, cache tables, and read parquet files. - """ - - def __init__(self, sparkContext, sqlContext=None): - """Create a new SQLContext. - - :param sparkContext: The SparkContext to wrap. - :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new - SQLContext in the JVM, instead we make all calls to this object. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError:... - - >>> bad_rdd = sc.parallelize([1,2,3]) - >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... - - >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, - ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), - ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' - ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] - """ - self._sc = sparkContext - self._jsc = self._sc._jsc - self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - - @property - def _ssql_ctx(self): - """Accessor for the JVM Spark SQL context. - - Subclasses can override this property to provide their own - JVM Contexts. - """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext - - def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] - """ - func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, - AutoBatchedSerializer(PickleSerializer()), - AutoBatchedSerializer(PickleSerializer())) - ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - broadcast = self._sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = ListConverter().convert( - [x._jbroadcast for x in self._sc._pickled_broadcast_vars], - self._sc._gateway._gateway_client) - self._sc._pickled_broadcast_vars.clear() - env = MapConverter().convert(self._sc.environment, - self._sc._gateway._gateway_client) - includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) - self._ssql_ctx.udf().registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) - - def inferSchema(self, rdd, samplingRatio=None): - """Infer and apply a schema to an RDD of L{Row}. - - When samplingRatio is specified, the schema is inferred by looking - at the types of each row in the sampled dataset. Otherwise, the - first 100 rows of the RDD are inspected. Nested collections are - supported, which can include array, dict, list, Row, tuple, - namedtuple, or object. - - Each row could be L{pyspark.sql.Row} object or namedtuple or objects. - Using top level dicts is deprecated, as dict is used to represent Maps. - - If a single column has multiple distinct inferred types, it may cause - runtime exceptions. - - >>> rdd = sc.parallelize( - ... [Row(field1=1, field2="row1"), - ... Row(field1=2, field2="row2"), - ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] - Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] - Row(field1=1, field2=u'row1') - """ - - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") - - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio > 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - - converter = _create_converter(schema) - rdd = rdd.map(converter) - return self.applySchema(rdd, schema) - - def applySchema(self, rdd, schema): - """ - Applies the given schema to the given RDD of L{tuple} or L{list}. - - These tuples or lists can contain complex nested structures like - lists, maps or nested rows. - - The schema should be a StructType. - - It is important that the schema matches the types of the objects - in each row or exceptions could be thrown at runtime. - - >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) - >>> schema = StructType([StructField("field1", IntegerType(), False), - ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> srdd.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] - - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ - - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) - rows = rdd.take(10) - - for row in rows: - _verify_type(row, schema) - - # convert python objects to sql data - converter = _python_to_sql_converter(schema) - rdd = rdd.map(converter) - - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) - - def registerRDDAsTable(self, rdd, tableName): - """Registers the given RDD as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of - SQLContext. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) - else: - raise ValueError("Can only register SchemaRDD as table") - - def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """ - Loads a text file storing one JSON object per line as a - L{SchemaRDD}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] - """ - if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) - - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - - If the schema is provided, applies the given schema to this - JSON dataset. - - Otherwise, it samples the dataset with ratio `samplingRatio` to - determine the schema. - - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> schema = StructType([ - ... StructField("field2", StringType(), True), - ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) - - def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() - [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] - """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) - - def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) - - def cacheTable(self, tableName): - """Caches the specified table in-memory.""" - self._ssql_ctx.cacheTable(tableName) - - def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache.""" - self._ssql_ctx.uncacheTable(tableName) - - -class HiveContext(SQLContext): - - """A variant of Spark SQL that integrates with data stored in Hive. - - Configuration for Hive is read from hive-site.xml on the classpath. - It supports running both SQL and HiveQL commands. - """ - - def __init__(self, sparkContext, hiveContext=None): - """Create a new HiveContext. - - :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new - HiveContext in the JVM, instead we make all calls to this object. - """ - SQLContext.__init__(self, sparkContext) - - if hiveContext: - self._scala_HiveContext = hiveContext - - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) - - def _get_hive_ctx(self): - return self._jvm.HiveContext(self._jsc.sc()) - - -class LocalHiveContext(HiveContext): - - def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) - - def _get_hive_ctx(self): - return self._jvm.LocalHiveContext(self._jsc.sc()) - - -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls - - -@inherit_doc -class SchemaRDD(RDD): - - """An RDD of L{Row} objects that has an associated schema. - - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. - - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. - - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. - """ - - def __init__(self, jschema_rdd, sql_ctx): - self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None - self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) - - @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. - - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd - - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id - - def limit(self, num): - """Limit the result count to the number specified. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() - [] - """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) - - def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. - - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' - True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] - True - """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - - def saveAsParquetFile(self, path): - """Save the contents as a Parquet file, preserving the schema. - - Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) - True - """ - self._jschema_rdd.saveAsParquetFile(path) - - def registerTempTable(self, name): - """Registers this RDD as a temporary table using the given name. - - The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) - True - """ - self._jschema_rdd.registerTempTable(name) - - def registerAsTable(self, name): - """DEPRECATED: use registerTempTable() instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. - - Optionally overwriting any existing data. - """ - self._jschema_rdd.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) - - def schema(self): - """Returns the schema of this SchemaRDD (represented by - a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() - - def printSchema(self): - """Prints out the schema in the tree format.""" - print self.schemaString() - - def count(self): - """Return the number of elements in this RDD. - - Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() - 3L - >>> srdd.count() == srdd.map(lambda x: x).count() - True - """ - return self._jschema_rdd.count() - - def collect(self): - """Return a list that contains all of the rows in this RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() - [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] - """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) - - def take(self, num): - """Take the first num rows of the RDD. - - Each object in the list is a Row, the fields can be accessed as - attributes. - - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - """ - return self.limit(num).collect() - - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): - """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 - """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() - - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) - - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class - def cache(self): - self.is_cached = True - self._jschema_rdd.cache() - return self - - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): - self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) - return self - - def unpersist(self, blocking=True): - self.is_cached = False - self._jschema_rdd.unpersist(blocking) - return self - - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() - - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() - - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() - - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) - - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() - else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) - - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) - else: - raise ValueError("Can only intersect with another SchemaRDD") - - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) - - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) - else: - raise ValueError("Can only subtract another SchemaRDD") - - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this SchemaRDD. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) - - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) - - -def _test(): - import doctest - from pyspark.context import SparkContext - # let doctest run in pyspark.sql, so DataTypes can be picklable - import pyspark.sql - from pyspark.sql import Row, SQLContext - from pyspark.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) - globs['rdd'] = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py new file mode 100644 index 000000000000..6d54b9e49ed1 --- /dev/null +++ b/python/pyspark/sql/__init__.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Important classes of Spark SQL and DataFrames: + + - L{SQLContext} + Main entry point for :class:`DataFrame` and SQL functionality. + - L{DataFrame} + A distributed collection of data grouped into named columns. + - L{Column} + A column expression in a :class:`DataFrame`. + - L{Row} + A row of data in a :class:`DataFrame`. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive. + - L{GroupedData} + Aggregation methods, returned by :func:`DataFrame.groupBy`. + - L{DataFrameNaFunctions} + Methods for handling missing data (null values). + - L{functions} + List of built-in functions available for :class:`DataFrame`. + - L{types} + List of data types available. +""" +from __future__ import absolute_import + +# fix the module name conflict for Python 3+ +import sys +from . import _types as types +modname = __name__ + '.types' +types.__name__ = modname +# update the __module__ for all objects, make them picklable +for v in types.__dict__.values(): + if hasattr(v, "__module__") and v.__module__.endswith('._types'): + v.__module__ = modname +sys.modules[modname] = types +del modname, sys + +from pyspark.sql.types import Row +from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions + +__all__ = [ + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' +] diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py new file mode 100644 index 000000000000..95fb91ad4345 --- /dev/null +++ b/python/pyspark/sql/_types.py @@ -0,0 +1,1288 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import decimal +import time +import datetime +import keyword +import warnings +import json +import re +import weakref +from array import array +from operator import itemgetter + +if sys.version >= "3": + long = int + unicode = str + +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + +__all__ = [ + "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", + "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] + + +class DataType(object): + """Base class for data types.""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def simpleString(self): + return self.typeName() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + + +# This singleton pattern does not work with pickle, you will get +# another object after pickle and unpickle +class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" + + _instances = {} + + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + + +class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + +class NullType(PrimitiveType): + """Null type. + + The data type representing None, used for the types that cannot be inferred. + """ + + +class StringType(PrimitiveType): + """String data type. + """ + + +class BinaryType(PrimitiveType): + """Binary (byte array) data type. + """ + + +class BooleanType(PrimitiveType): + """Boolean data type. + """ + + +class DateType(PrimitiveType): + """Date (datetime.date) data type. + """ + + +class TimestampType(PrimitiveType): + """Timestamp (datetime.datetime) data type. + """ + + +class DecimalType(DataType): + """Decimal (decimal.Decimal) data type. + """ + + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def simpleString(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal(10,0)" + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + + +class DoubleType(PrimitiveType): + """Double data type, representing double precision floats. + """ + + +class FloatType(PrimitiveType): + """Float data type, representing single precision floats. + """ + + +class ByteType(PrimitiveType): + """Byte data type, i.e. a signed integer in a single byte. + """ + def simpleString(self): + return 'tinyint' + + +class IntegerType(PrimitiveType): + """Int data type, i.e. a signed 32-bit integer. + """ + def simpleString(self): + return 'int' + + +class LongType(PrimitiveType): + """Long data type, i.e. a signed 64-bit integer. + + If the values are beyond the range of [-9223372036854775808, 9223372036854775807], + please use :class:`DecimalType`. + """ + def simpleString(self): + return 'bigint' + + +class ShortType(PrimitiveType): + """Short data type, i.e. a signed 16-bit integer. + """ + def simpleString(self): + return 'smallint' + + +class ArrayType(DataType): + """Array data type. + + :param elementType: :class:`DataType` of each element in the array. + :param containsNull: boolean, whether the array can contain null (None) values. + """ + + def __init__(self, elementType, containsNull=True): + """ + >>> ArrayType(StringType()) == ArrayType(StringType(), True) + True + >>> ArrayType(StringType(), False) == ArrayType(StringType()) + False + """ + assert isinstance(elementType, DataType), "elementType should be DataType" + self.elementType = elementType + self.containsNull = containsNull + + def simpleString(self): + return 'array<%s>' % self.elementType.simpleString() + + def __repr__(self): + return "ArrayType(%s,%s)" % (self.elementType, + str(self.containsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + + +class MapType(DataType): + """Map data type. + + :param keyType: :class:`DataType` of the keys in the map. + :param valueType: :class:`DataType` of the values in the map. + :param valueContainsNull: indicates whether values can contain null (None) values. + + Keys in a map data type are not allowed to be null (None). + """ + + def __init__(self, keyType, valueType, valueContainsNull=True): + """ + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) + True + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) + False + """ + assert isinstance(keyType, DataType), "keyType should be DataType" + assert isinstance(valueType, DataType), "valueType should be DataType" + self.keyType = keyType + self.valueType = valueType + self.valueContainsNull = valueContainsNull + + def simpleString(self): + return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) + + def __repr__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) + + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + + +class StructField(DataType): + """A field in :class:`StructType`. + + :param name: string, name of the field. + :param dataType: :class:`DataType` of the field. + :param nullable: boolean, whether the field can be null (None) or not. + :param metadata: a dict from string to simple type that can be serialized to JSON automatically + """ + + def __init__(self, name, dataType, nullable=True, metadata=None): + """ + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) + True + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) + False + """ + assert isinstance(dataType, DataType), "dataType should be DataType" + self.name = name + self.dataType = dataType + self.nullable = nullable + self.metadata = metadata or {} + + def simpleString(self): + return '%s:%s' % (self.name, self.dataType.simpleString()) + + def __repr__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) + + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + + +class StructType(DataType): + """Struct type, consisting of a list of :class:`StructField`. + + This is the data type representing a :class:`Row`. + """ + + def __init__(self, fields): + """ + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) + >>> struct1 == struct2 + False + """ + assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" + self.fields = fields + + def simpleString(self): + return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + + def __repr__(self): + return ("StructType(List(%s))" % + ",".join(str(field) for field in self.fields)) + + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} + + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) + + +class UserDefinedType(DataType): + """User-defined type (UDT). + + .. note:: WARN: Spark Internal Use Only + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. + """ + raise NotImplementedError("UDT must implement deserialize().") + + def simpleString(self): + return 'udt' + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass]) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + +_all_primitive_types = dict((v.typeName(), v) + for v in list(globals().values()) + if (type(v) is type or type(v) is PrimitiveTypeSingleton) + and v.__base__ == PrimitiveType) + +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. + >>> import pickle + >>> def check_datatype(datatype): + ... pickled = pickle.loads(pickle.dumps(datatype)) + ... assert datatype == pickled + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) + ... assert datatype == python_datatype + >>> for cls in _all_primitive_types.values(): + ... check_datatype(cls()) + + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + + >>> # Simple MapType. + >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) + >>> check_datatype(complex_structtype) + + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) + >>> check_datatype(complex_maptype) + + >>> check_datatype(ExamplePointUDT()) + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + """ + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + +def _parse_datatype_json_value(json_value): + if not isinstance(json_value, dict): + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == 'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType +_type_mappings = { + type(None): NullType, + bool: BooleanType, + int: LongType, + float: DoubleType, + str: StringType, + bytearray: BinaryType, + decimal.Decimal: DecimalType, + datetime.date: DateType, + datetime.datetime: TimestampType, + datetime.time: TimestampType, +} + +if sys.version < "3": + _type_mappings.update({ + unicode: StringType, + long: LongType, + }) + + +def _infer_type(obj): + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ + if obj is None: + return NullType() + + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + for key, value in obj.items(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) + elif isinstance(obj, (list, array)): + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) + else: + try: + return _infer_schema(obj) + except TypeError: + raise TypeError("not supported type: %s" % type(obj)) + + +def _infer_schema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + + elif isinstance(row, (tuple, list)): + if hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) + elif hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + else: + names = ['_%d' % i for i in range(1, len(row) + 1)] + items = zip(names, row) + + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + + else: + raise TypeError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _infer_type(v), True) for k, v in items] + return StructType(fields) + + +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + +def _create_converter(dataType): + """Create an converter to drop the names of fields in obj """ + if not _need_converter(dataType): + return lambda x: x + + if isinstance(dataType, ArrayType): + conv = _create_converter(dataType.elementType) + return lambda row: [conv(v) for v in row] + + elif isinstance(dataType, MapType): + kconv = _create_converter(dataType.keyType) + vconv = _create_converter(dataType.valueType) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items()) + + elif isinstance(dataType, NullType): + return lambda x: None + + elif not isinstance(dataType, StructType): + return lambda x: x + + # dataType must be StructType + names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, (tuple, list)): + if convert_fields: + return tuple(conv(v) for v, conv in zip(obj, converters)) + else: + return tuple(obj) + + if isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ + else: + raise TypeError("Unexpected obj type: %s" % type(obj)) + + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) + + return convert_struct + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> _split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRACKETS: + brackets.append(c) + elif c in _BRACKETS.values(): + if not brackets or c != _BRACKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,NullType,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(...c,NullType,true),StructField(d... + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(NullType,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) + """ + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, NullType(), True) + + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,NullType,true))),true) + """ + s = s.strip() + if not s: + return NullType() + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(NullType(), _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types inferred from obj + + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) + >>> _infer_schema_type(row, schema) + StructType...LongType...DoubleType...StringType...DateType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,...c,LongType... + """ + if isinstance(dataType, NullType): + return _infer_type(obj) + + if not obj: + return NullType() + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = next(iter(obj.items())) + return MapType(_infer_schema_type(k, dataType.keyType), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise TypeError("Unexpected dataType: %s" % type(dataType)) + + +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + BinaryType: (bytearray,), + DateType: (datetime.date,), + TimestampType: (datetime.datetime,), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. + + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, LongType()) + >>> _verify_type(list(range(3)), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + + _type = type(dataType) + assert _type in _acceptable_types, "unknown datatype: %s" % dataType + + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.items(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + +_cached_cls = weakref.WeakValueDictionary() + + +def _restore_object(dataType, obj): + """ Restore object during unpickling. """ + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in most cases. + k = id(dataType) + cls = _cached_cls.get(k) + if cls is None or cls.__datatype is not dataType: + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + cls.__datatype = dataType + _cached_cls[k] = cls + return cls(obj) + + +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() + return cls(v) if v is not None else v + + +def _create_getter(dt, i): + """ Create a getter for item `i` with schema """ + cls = _create_cls(dt) + + def getter(self): + return _create_object(cls, self[i]) + + return getter + + +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct_or_date(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True + return False + + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): + warnings.warn("field name %s can not be accessed in Python," + "use position to access it instead" % name) + if _has_struct_or_date(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + + +def _create_cls(dataType): + """ + Create an class by dataType + + The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} + """ + + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) + + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] + + return List + + elif isinstance(dataType, MapType): + kcls = _create_cls(dataType.keyType) + vcls = _create_cls(dataType.valueType) + + def Dict(d): + if d is None: + return + return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) + + return Dict + + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + + elif not isinstance(dataType, StructType): + # no wrapper for primitive types + return lambda x: x + + class Row(tuple): + + """ Row in DataFrame """ + __datatype = dataType + __fields__ = tuple(f.name for f in dataType.fields) + __slots__ = () + + # create property for fast access + locals().update(_create_properties(dataType.fields)) + + def asDict(self): + """ Return as a dict """ + return dict((n, getattr(self, n)) for n in self.__fields__) + + def __repr__(self): + # call collect __repr__ for nested objects + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self.__fields__)) + + def __reduce__(self): + return (_restore_object, (self.__datatype, tuple(self))) + + return Row + + +def _create_row(fields, values): + row = Row(*values) + row.__fields__ = fields + return row + + +class Row(tuple): + + """ + A row in L{DataFrame}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) + """ + + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + row = tuple.__new__(self, [kwargs[n] for n in names]) + row.__fields__ = names + return row + + else: + raise ValueError("No args or kwargs") + + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__fields__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__fields__, self)) + + # let object acts like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__fields__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + except ValueError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__fields__"): + return (_create_row, (self.__fields__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__fields__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__fields__, tuple(self))) + else: + return "" % ", ".join(self) + + +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) +register_input_converter(DateConverter()) + + +def _test(): + import doctest + from pyspark.context import SparkContext + # let doctest run in pyspark.sql.types, so DataTypes can be picklable + import pyspark.sql.types + from pyspark.sql import Row, SQLContext + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + globs = pyspark.sql.types.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT + (failure_count, test_count) = doctest.testmod( + pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py new file mode 100644 index 000000000000..f6f107ca32d2 --- /dev/null +++ b/python/pyspark/sql/context.py @@ -0,0 +1,635 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import warnings +import json + +if sys.version >= '3': + basestring = unicode = str +else: + from itertools import imap as map + +from py4j.protocol import Py4JError + +from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter +from pyspark.sql.dataframe import DataFrame + +try: + import pandas + has_pandas = True +except ImportError: + has_pandas = False + +__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] + + +def _monkey_patch_RDD(sqlContext): + def toDF(self, schema=None, sampleRatio=None): + """ + Converts current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` + + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sqlContext.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + +class SQLContext(object): + """Main entry point for Spark SQL functionality. + + A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as + tables, execute SQL over tables, cache tables, and read parquet files. + + :param sparkContext: The :class:`SparkContext` backing this SQLContext. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new + SQLContext in the JVM, instead we make all calls to this object. + """ + + @ignore_unicode_prefix + def __init__(self, sparkContext, sqlContext=None): + """Creates a new SQLContext. + + >>> from datetime import datetime + >>> sqlContext = SQLContext(sc) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = allTypes.toDF() + >>> df.registerTempTable("allTypes") + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + self._scala_SQLContext = sqlContext + _monkey_patch_RDD(self) + + @property + def _ssql_ctx(self): + """Accessor for the JVM Spark SQL context. + + Subclasses can override this property to provide their own + JVM Contexts. + """ + if self._scala_SQLContext is None: + self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) + return self._scala_SQLContext + + def setConf(self, key, value): + """Sets the given Spark SQL configuration property. + """ + self._ssql_ctx.setConf(key, value) + + def getConf(self, key, defaultValue): + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set, returns defaultValue. + """ + return self._ssql_ctx.getConf(key, defaultValue) + + @property + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration.""" + return UDFRegistration(self) + + @ignore_unicode_prefix + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + :param name: name of the UDF + :param samplingRatio: lambda function + :param returnType: a :class:`DataType` object + + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + + >>> from pyspark.sql.types import IntegerType + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + """ + func = lambda _, it: map(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_cmd), + env, + includes, + self._sc.pythonExec, + bvars, + self._sc._javaAccumulator, + returnType.json()) + + def _inferSchema(self, rdd, samplingRatio=None): + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + + @ignore_unicode_prefix + def inferSchema(self, rdd, samplingRatio=None): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + return self.createDataFrame(rdd, None, samplingRatio) + + @ignore_unicode_prefix + def applySchema(self, rdd, schema): + """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("applySchema is deprecated, please use createDataFrame instead") + + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") + + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType, but got %s" % type(schema)) + + return self.createDataFrame(rdd, schema) + + @ignore_unicode_prefix + def createDataFrame(self, data, schema=None, samplingRatio=None): + """ + Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, + list or :class:`pandas.DataFrame`. + + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. + + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. + + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. + + :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, + :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`StructType` or list of column names. default None. + :param samplingRatio: the sample ratio of rows used for inferring + + >>> l = [('Alice', 1)] + >>> sqlContext.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> sqlContext.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) + >>> sqlContext.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = sqlContext.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = sqlContext.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] + """ + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = list(data.columns) + data = [r.tolist() for r in data.to_records(index=False)] + + if not isinstance(data, RDD): + try: + # data could be list, tuple, generator ... + rdd = self._sc.parallelize(data) + except Exception: + raise TypeError("cannot create an RDD from type: %s" % type(data)) + else: + rdd = data + + if schema is None: + schema = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(schema) + rdd = rdd.map(converter) + + if isinstance(schema, (list, tuple)): + first = rdd.first() + if not isinstance(first, (list, tuple)): + raise TypeError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) + row_cls = Row(*schema) + schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) + + # take the first few rows to verify schema + rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + + for row in rows: + _verify_type(row, schema) + + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) + + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + """ + if (df.__class__ is DataFrame): + self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + def parquetFile(self, *paths): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + gateway = self._sc._gateway + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) + for i in range(0, len(paths)): + jpaths[i] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpaths) + return DataFrame(jdf, self) + + def jsonFile(self, path, schema=None, samplingRatio=1.0): + """Loads a text file storing one JSON object per line as a :class:`DataFrame`. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + + >>> import tempfile, shutil + >>> jsonFile = tempfile.mkdtemp() + >>> shutil.rmtree(jsonFile) + >>> with open(jsonFile, 'w') as f: + ... f.writelines(jsonStrings) + >>> df1 = sqlContext.jsonFile(jsonFile) + >>> df1.printSchema() + root + |-- field1: long (nullable = true) + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field4: long (nullable = true) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) + >>> df2.printSchema() + root + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field5: array (nullable = true) + | | |-- element: integer (containsNull = true) + """ + if schema is None: + df = self._ssql_ctx.jsonFile(path, samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) + + @ignore_unicode_prefix + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): + """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + + >>> df1 = sqlContext.jsonRDD(json) + >>> df1.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> df2 = sqlContext.jsonRDD(json, df1.schema) + >>> df2.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("field2", StringType()), + ... StructField("field3", + ... StructType([StructField("field5", ArrayType(IntegerType()))])) + ... ]) + >>> df3 = sqlContext.jsonRDD(json, schema) + >>> df3.first() + Row(field2=u'row1', field3=Row(field5=None)) + """ + + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = rdd.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._jvm.BytesToString()) + if schema is None: + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) + + def load(self, path=None, source=None, schema=None, **options): + """Returns the dataset in a data source as a :class:`DataFrame`. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned DataFrame. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + if schema is None: + df = self._ssql_ctx.load(source, options) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.load(source, scala_datatype, options) + return DataFrame(df, self) + + def createExternalTable(self, tableName, path=None, source=None, + schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + if schema is None: + df = self._ssql_ctx.createExternalTable(tableName, source, options) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) + df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, + options) + return DataFrame(df, self) + + @ignore_unicode_prefix + def sql(self, sqlQuery): + """Returns a :class:`DataFrame` representing the result of the given query. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._ssql_ctx.table(tableName), self) + + @ignore_unicode_prefix + def tables(self, dbName=None): + """Returns a :class:`DataFrame` containing names of tables in the given database. + + If ``dbName`` is not specified, the current database will be used. + + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() + >>> df2.filter("tableName = 'table1'").first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + + def tableNames(self, dbName=None): + """Returns a list of names of tables in the database ``dbName``. + + If ``dbName`` is not specified, the current database will be used. + + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() + True + >>> "table1" in sqlContext.tableNames("db") + True + """ + if dbName is None: + return [name for name in self._ssql_ctx.tableNames()] + else: + return [name for name in self._ssql_ctx.tableNames(dbName)] + + def cacheTable(self, tableName): + """Caches the specified table in-memory.""" + self._ssql_ctx.cacheTable(tableName) + + def uncacheTable(self, tableName): + """Removes the specified table from the in-memory cache.""" + self._ssql_ctx.uncacheTable(tableName) + + def clearCache(self): + """Removes all cached tables from the in-memory cache. """ + self._ssql_ctx.clearCache() + + +class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from ``hive-site.xml`` on the classpath. + It supports running both SQL and HiveQL commands. + + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :class:`HiveContext` in the JVM, instead we make all calls to this object. + """ + + def __init__(self, sparkContext, hiveContext=None): + SQLContext.__init__(self, sparkContext) + if hiveContext: + self._scala_HiveContext = hiveContext + + @property + def _ssql_ctx(self): + try: + if not hasattr(self, '_scala_HiveContext'): + self._scala_HiveContext = self._get_hive_ctx() + return self._scala_HiveContext + except Py4JError as e: + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", e) + + def _get_hive_ctx(self): + return self._jvm.HiveContext(self._jsc.sc()) + + def refreshTable(self, tableName): + """Invalidate and refresh all the cached the metadata of the given + table. For performance reasons, Spark SQL or the external data source + library it uses might cache certain metadata about a table, such as the + location of blocks. When those change outside of Spark SQL, users should + call this function to invalidate the cache. + """ + self._ssql_ctx.refreshTable(tableName) + + +class UDFRegistration(object): + """Wrapper for user-defined function registration.""" + + def __init__(self, sqlContext): + self.sqlContext = sqlContext + + def register(self, name, f, returnType=StringType()): + return self.sqlContext.registerFunction(name, f, returnType) + + register.__doc__ = SQLContext.registerFunction.__doc__ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.context + globs = pyspark.sql.context.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['rdd'] = rdd = sc.parallelize( + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] + ) + globs['df'] = rdd.toDF() + jsonStrings = [ + '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' + ] + globs['jsonStrings'] = jsonStrings + globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( + pyspark.sql.context, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py new file mode 100644 index 000000000000..4759f5fe783a --- /dev/null +++ b/python/pyspark/sql/dataframe.py @@ -0,0 +1,1321 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import warnings +import random + +if sys.version >= '3': + basestring = unicode = str + long = int +else: + from itertools import imap as map + +from pyspark.context import SparkContext +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.sql.types import * +from pyspark.sql.types import _create_cls, _parse_datatype_json_string + + +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] + + +class DataFrame(object): + """A distributed collection of data grouped into named columns. + + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: + + people = sqlContext.parquetFile("...") + + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. + + To select a column from the data frame, use the apply method:: + + ageCol = people.age + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + self._sc = sql_ctx and sql_ctx._sc + self.is_cached = False + self._schema = None # initialized lazily + self._lazy_rdd = None + + @property + def rdd(self): + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. + """ + if self._lazy_rdd is None: + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema + + def applySchema(it): + cls = _create_cls(schema) + return map(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd + + @property + def na(self): + """Returns a :class:`DataFrameNaFunctions` for handling missing values. + """ + return DataFrameNaFunctions(self) + + @ignore_unicode_prefix + def toJSON(self, use_unicode=True): + """Converts a :class:`DataFrame` into a :class:`RDD` of string. + + Each row is turned into a JSON document as one element in the returned RDD. + + >>> df.toJSON().first() + u'{"age":2,"name":"Alice"}' + """ + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + + def saveAsParquetFile(self, path): + """Saves the contents as a Parquet file, preserving the schema. + + Files that are written out using this method can be read back in as + a :class:`DataFrame` using :func:`SQLContext.parquetFile`. + + >>> import tempfile, shutil + >>> parquetFile = tempfile.mkdtemp() + >>> shutil.rmtree(parquetFile) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) + True + """ + self._jdf.saveAsParquetFile(path) + + def registerTempTable(self, name): + """Registers this RDD as a temporary table using the given name. + + The lifetime of this temporary table is tied to the :class:`SQLContext` + that was used to create this :class:`DataFrame`. + + >>> df.registerTempTable("people") + >>> df2 = sqlContext.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + self._jdf.registerTempTable(name) + + def registerAsTable(self, name): + """DEPRECATED: use :func:`registerTempTable` instead""" + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) + + def insertInto(self, tableName, overwrite=False): + """Inserts the contents of this :class:`DataFrame` into the specified table. + + Optionally overwriting any existing data. + """ + self._jdf.insertInto(tableName, overwrite) + + def _java_save_mode(self, mode): + """Returns the Java save mode based on the Python save mode represented by a string. + """ + jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode + jmode = jSaveMode.ErrorIfExists + mode = mode.lower() + if mode == "append": + jmode = jSaveMode.Append + elif mode == "overwrite": + jmode = jSaveMode.Overwrite + elif mode == "ignore": + jmode = jSaveMode.Ignore + elif mode == "error": + pass + else: + raise ValueError( + "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") + return jmode + + def saveAsTable(self, tableName, source=None, mode="error", **options): + """Saves the contents of this :class:`DataFrame` to a data source as a table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Additionally, mode is used to specify the behavior of the saveAsTable operation when + table already exists in the data source. There are four modes: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + self._jdf.saveAsTable(tableName, source, jmode, options) + + def save(self, path=None, source=None, mode="error", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Additionally, mode is used to specify the behavior of the save operation when + data already exists in the data source. There are four modes: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + if path is not None: + options["path"] = path + if source is None: + source = self.sql_ctx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + jmode = self._java_save_mode(mode) + self._jdf.save(source, jmode, options) + + @property + def schema(self): + """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. + + >>> df.schema + StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + """ + if self._schema is None: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + return self._schema + + def printSchema(self): + """Prints out the schema in the tree format. + + >>> df.printSchema() + root + |-- age: integer (nullable = true) + |-- name: string (nullable = true) + + """ + print(self._jdf.schema().treeString()) + + def explain(self, extended=False): + """Prints the (logical and physical) plans to the console for debugging purpose. + + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + + >>> df.explain() + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + + >>> df.explain(True) + == Parsed Logical Plan == + ... + == Analyzed Logical Plan == + ... + == Optimized Logical Plan == + ... + == Physical Plan == + ... + == RDD == + """ + if extended: + print(self._jdf.queryExecution().toString()) + else: + print(self._jdf.queryExecution().executedPlan().toString()) + + def isLocal(self): + """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally + (without any Spark executors). + """ + return self._jdf.isLocal() + + def show(self, n=20): + """Prints the first ``n`` rows to the console. + + >>> df + DataFrame[age: int, name: string] + >>> df.show() + age name + 2 Alice + 5 Bob + """ + print(self._jdf.showString(n)) + + def __repr__(self): + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def count(self): + """Returns the number of rows in this :class:`DataFrame`. + + >>> df.count() + 2 + """ + return int(self._jdf.count()) + + @ignore_unicode_prefix + def collect(self): + """Returns all the records as a list of :class:`Row`. + + >>> df.collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + cls = _create_cls(self.schema) + return [cls(r) for r in rs] + + @ignore_unicode_prefix + def limit(self, num): + """Limits the result count to the number specified. + + >>> df.limit(1).collect() + [Row(age=2, name=u'Alice')] + >>> df.limit(0).collect() + [] + """ + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + def take(self, num): + """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. + + >>> df.take(2) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + return self.limit(num).collect() + + @ignore_unicode_prefix + def map(self, f): + """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. + + This is a shorthand for ``df.rdd.map()``. + + >>> df.map(lambda p: p.name).collect() + [u'Alice', u'Bob'] + """ + return self.rdd.map(f) + + @ignore_unicode_prefix + def flatMap(self, f): + """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, + and then flattening the results. + + This is a shorthand for ``df.rdd.flatMap()``. + + >>> df.flatMap(lambda p: p.name).collect() + [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] + """ + return self.rdd.flatMap(f) + + def mapPartitions(self, f, preservesPartitioning=False): + """Returns a new :class:`RDD` by applying the ``f`` function to each partition. + + This is a shorthand for ``df.rdd.mapPartitions()``. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) + + def foreach(self, f): + """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. + + This is a shorthand for ``df.rdd.foreach()``. + + >>> def f(person): + ... print(person.name) + >>> df.foreach(f) + """ + return self.rdd.foreach(f) + + def foreachPartition(self, f): + """Applies the ``f`` function to each partition of this :class:`DataFrame`. + + This a shorthand for ``df.rdd.foreachPartition()``. + + >>> def f(people): + ... for person in people: + ... print(person.name) + >>> df.foreachPartition(f) + """ + return self.rdd.foreachPartition(f) + + def cache(self): + """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self._jdf.cache() + return self + + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """Sets the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) + return self + + def unpersist(self, blocking=True): + """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from + memory and disk. + """ + self.is_cached = False + self._jdf.unpersist(blocking) + return self + + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + """ + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + + def distinct(self): + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + + >>> df.distinct().count() + 2 + """ + return DataFrame(self._jdf.distinct(), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """Returns a sampled subset of this :class:`DataFrame`. + + >>> df.sample(False, 0.5, 97).count() + 1 + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + @property + def dtypes(self): + """Returns all column names and their data types as a list. + + >>> df.dtypes + [('age', 'int'), ('name', 'string')] + """ + return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] + + @property + @ignore_unicode_prefix + def columns(self): + """Returns all column names as a list. + + >>> df.columns + [u'age', u'name'] + """ + return [f.name for f in self.schema.fields] + + @ignore_unicode_prefix + def alias(self, alias): + """Returns a new :class:`DataFrame` with an alias set. + + >>> from pyspark.sql.functions import * + >>> df_as1 = df.alias("df_as1") + >>> df_as2 = df.alias("df_as2") + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') + >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() + [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + """ + assert isinstance(alias, basestring), "alias should be a string" + return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + + @ignore_unicode_prefix + def join(self, other, joinExprs=None, joinType=None): + """Joins with another :class:`DataFrame`, using the given join expression. + + The following performs a full outer join between ``df1`` and ``df2``. + + :param other: Right side of the join + :param joinExprs: a string for join column name, or a join expression (Column). + If joinExprs is a string indicating the name of the join column, + the column must exist on both sides, and this performs an inner equi-join. + :param joinType: str, default 'inner'. + One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + + >>> df.join(df2, 'name').select(df.name, df2.height).collect() + [Row(name=u'Bob', height=85)] + """ + + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + elif isinstance(joinExprs, basestring): + jdf = self._jdf.join(other._jdf, joinExprs) + else: + assert isinstance(joinExprs, Column), "joinExprs should be Column" + if joinType is None: + jdf = self._jdf.join(other._jdf, joinExprs._jc) + else: + assert isinstance(joinType, basestring), "joinType should be basestring" + jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + def sort(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sort(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sort("age", ascending=False).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> from pyspark.sql.functions import * + >>> df.sort(asc("age")).collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.orderBy(desc("age"), "name").collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + + jdf = self._jdf.sort(self._jseq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + orderBy = sort + + def _jseq(self, cols, converter=None): + """Return a JVM Seq of Columns from a list of Column or names""" + return _to_seq(self.sql_ctx._sc, cols, converter) + + def _jcols(self, *cols): + """Return a JVM Seq of Columns from a list of Column or column names + + If `cols` has only one list in it, cols[0] will be used as the list. + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return self._jseq(cols, _to_java_column) + + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + jdf = self._jdf.describe(self._jseq(cols)) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + def head(self, n=None): + """ + Returns the first ``n`` rows as a list of :class:`Row`, + or the first :class:`Row` if ``n`` is ``None.`` + + >>> df.head() + Row(age=2, name=u'Alice') + >>> df.head(1) + [Row(age=2, name=u'Alice')] + """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + @ignore_unicode_prefix + def first(self): + """Returns the first row as a :class:`Row`. + + >>> df.first() + Row(age=2, name=u'Alice') + """ + return self.head() + + @ignore_unicode_prefix + def __getitem__(self, item): + """Returns the column as a :class:`Column`. + + >>> df.select(df['age']).collect() + [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] + >>> df[df[0] > 3].collect() + [Row(age=5, name=u'Bob')] + """ + if isinstance(item, basestring): + if item not in self.columns: + raise IndexError("no such column: %s" % item) + jc = self._jdf.apply(item) + return Column(jc) + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, (list, tuple)): + return self.select(*item) + elif isinstance(item, int): + jc = self._jdf.apply(self.columns[item]) + return Column(jc) + else: + raise TypeError("unexpected item type: %s" % type(item)) + + def __getattr__(self, name): + """Returns the :class:`Column` denoted by ``name``. + + >>> df.select(df.age).collect() + [Row(age=2), Row(age=5)] + """ + if name not in self.columns: + raise AttributeError("No such column: %s" % name) + jc = self._jdf.apply(name) + return Column(jc) + + @ignore_unicode_prefix + def select(self, *cols): + """Projects a set of expressions and returns a new :class:`DataFrame`. + + :param cols: list of column names (string) or expressions (:class:`Column`). + If one of the column names is '*', that column is expanded to include all columns + in the current DataFrame. + + >>> df.select('*').collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.select('name', 'age').collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df.select(df.name, (df.age + 10).alias('age')).collect() + [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] + """ + jdf = self._jdf.select(self._jcols(*cols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """Projects a set of SQL expressions and returns a new :class:`DataFrame`. + + This is a variant of :func:`select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + """ + if len(expr) == 1 and isinstance(expr[0], list): + expr = expr[0] + jdf = self._jdf.selectExpr(self._jseq(expr)) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix + def filter(self, condition): + """Filters rows using the given condition. + + :func:`where` is an alias for :func:`filter`. + + :param condition: a :class:`Column` of :class:`types.BooleanType` + or a string of SQL expression. + + >>> df.filter(df.age > 3).collect() + [Row(age=5, name=u'Bob')] + >>> df.where(df.age == 2).collect() + [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] + """ + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) + + where = filter + + @ignore_unicode_prefix + def groupBy(self, *cols): + """Groups the :class:`DataFrame` using the specified columns, + so we can run aggregation on them. See :class:`GroupedData` + for all the available aggregate functions. + + :func:`groupby` is an alias for :func:`groupBy`. + + :param cols: list of columns to group by. + Each element should be a column name (string) or an expression (:class:`Column`). + + >>> df.groupBy().avg().collect() + [Row(AVG(age)=3.5)] + >>> df.groupBy('name').agg({'age': 'mean'}).collect() + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + >>> df.groupBy(df.name).avg().collect() + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + >>> df.groupBy(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] + """ + jdf = self._jdf.groupBy(self._jcols(*cols)) + return GroupedData(jdf, self.sql_ctx) + + groupby = groupBy + + def agg(self, *exprs): + """ Aggregate on the entire :class:`DataFrame` without groups + (shorthand for ``df.groupBy.agg()``). + + >>> df.agg({"age": "max"}).collect() + [Row(MAX(age)=5)] + >>> from pyspark.sql import functions as F + >>> df.agg(F.min(df.age)).collect() + [Row(MIN(age)=2)] + """ + return self.groupBy().agg(*exprs) + + def unionAll(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + + def intersect(self, other): + """ Return a new :class:`DataFrame` containing rows only in + both this frame and another frame. + + This is equivalent to `INTERSECT` in SQL. + """ + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new :class:`DataFrame` containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def dropna(self, how='any', thresh=None, subset=None): + """Returns a new :class:`DataFrame` omitting rows with null values. + + This is an alias for ``na.drop()``. + + :param how: 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + :param thresh: int, default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + :param subset: optional list of column names to consider. + + >>> df4.dropna().show() + age height name + 10 80 Alice + + >>> df4.na.drop().show() + age height name + 10 80 Alice + """ + if how is not None and how not in ['any', 'all']: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if subset is None: + subset = self.columns + elif isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + if thresh is None: + thresh = len(subset) if how == 'any' else 1 + + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) + + def fillna(self, value, subset=None): + """Replace null values, alias for ``na.fill()``. + + :param value: int, long, float, string, or dict. + Value to replace null values with. + If the value is a dict, then `subset` is ignored and `value` must be a mapping + from column name (string) to replacement value. The replacement value must be + an int, long, float, or string. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + + >>> df4.fillna(50).show() + age height name + 10 80 Alice + 5 50 Bob + 50 50 Tom + 50 50 null + + >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + + >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() + age height name + 10 80 Alice + 5 null Bob + 50 null Tom + 50 null unknown + """ + if not isinstance(value, (float, int, long, basestring, dict)): + raise ValueError("value should be a float, int, long, string, or dict") + + if isinstance(value, (int, long)): + value = float(value) + + if isinstance(value, dict): + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + elif subset is None: + return DataFrame(self._jdf.na().fill(value), self.sql_ctx) + else: + if isinstance(subset, basestring): + subset = [subset] + elif not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + + @ignore_unicode_prefix + def withColumn(self, colName, col): + """Returns a new :class:`DataFrame` by adding a column. + + :param colName: string, name of the new column. + :param col: a :class:`Column` expression for the new column. + + >>> df.withColumn('age2', df.age + 2).collect() + [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ + return self.select('*', col.alias(colName)) + + @ignore_unicode_prefix + def withColumnRenamed(self, existing, new): + """REturns a new :class:`DataFrame` by renaming an existing column. + + :param existing: string, name of the existing column to rename. + :param col: string, new name of the column. + + >>> df.withColumnRenamed('age', 'age2').collect() + [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + """ + cols = [Column(_to_java_column(c)).alias(new) + if c == existing else c + for c in self.columns] + return self.select(*cols) + + def toPandas(self): + """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. + + >>> df.toPandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """SchemaRDD is deprecated, please use :class:`DataFrame`. + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(MIN(age)=2), Row(MIN(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age)=5, MAX(height)=85)] + """ + + @df_varargs_api + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age)=2, MIN(height)=80)] + """ + + @df_varargs_api + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().sum('age').collect() + [Row(SUM(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age)=7, SUM(height)=165)] + """ + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + l[0] d[key] + 1 value + >>> df.select(df.l[0], df.d["key"]).show() + l[0] d[key] + 1 value + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + r.b + b + >>> df.select(df.r.a).show() + r.a + 1 + """ + return Column(self._jc.getField(name)) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, alias): + """Return a alias for this column + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + return Column(getattr(self._jc, "as")(alias)) + + @ignore_unicode_prefix + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +class DataFrameNaFunctions(object): + """Functionality for working with missing data in :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def drop(self, how='any', thresh=None, subset=None): + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + + def fill(self, value, subset=None): + return self.df.fillna(value=value, subset=subset) + + fill.__doc__ = DataFrame.fillna.__doc__ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), + Row(name='Bob', age=5, height=None), + Row(name='Tom', age=None, height=None), + Row(name=None, age=None, height=None)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py new file mode 100644 index 000000000000..bb47923f24b8 --- /dev/null +++ b/python/pyspark/sql/functions.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collections of builtin functions +""" +import sys + +if sys.version < "3": + from itertools import imap as map + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql.types import StringType +from pyspark.sql.dataframe import Column, _to_java_column, _to_seq + + +__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] + + +def _create_function(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +_functions = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'asc': 'Returns a sort expression based on the ascending order of the given column name.', + 'desc': 'Returns a sort expression based on the descending order of the given column name.', + + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + + +for _name, _doc in _functions.items(): + globals()[_name] = _create_function(_name, _doc) +del _name, _doc +__all__ += _functions.keys() +__all__.sort() + + +def countDistinct(col, *cols): + """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. + + >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + +class UserDefinedFunction(object): + """ + User defined function in Python + """ + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: map(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jc = self._judf.apply(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +def udf(f, returnType=StringType()): + """Creates a :class:`Column` expression representing a user defined function (UDF). + + >>> from pyspark.sql.types import IntegerType + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.functions + globs = pyspark.sql.functions.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py new file mode 100644 index 000000000000..fe43c374f1cb --- /dev/null +++ b/python/pyspark/sql/tests.py @@ -0,0 +1,655 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for pyspark.sql; additional tests are implemented as doctests in +individual modules. +""" +import os +import sys +import pydoc +import shutil +import tempfile +import pickle +import functools +import datetime + +import py4j + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql.types import * +from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEquals(lt, lt2) + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sqlCtx = SQLContext(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData, 2) + cls.df = rdd.toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=list(range(3)), d={"key": list(range(5))})] + rdd = self.sc.parallelize(d) + self.sqlCtx.createDataFrame(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(list(range(3)), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() + + def test_apply_schema_to_row(self): + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.sqlCtx.createDataFrame(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.sqlCtx.inferSchema(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.sqlCtx.inferSchema(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.sqlCtx.inferSchema(rdd) + self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.createDataFrame(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.createDataFrame(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + df = self.sc.parallelize(d).toDF() + k, v = list(df.head().m.items())[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.sc.parallelize([row]).toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = rdd.toDF(schema) + point = df.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.sc.parallelize([row]).toDF() + output_dir = os.path.join(self.tempdir.name, "labeled_point") + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. + output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + self.assertEquals('a', df1.first().f1) + self.assertEquals(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.sqlCtx.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # fillna shouldn't change non-null values + row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with string + row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + + # fillna with subset specified for numeric cols + row = self.sqlCtx.createDataFrame( + [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.sqlCtx = None + return + except TypeError: + cls.sqlCtx = None + return + os.unlink(cls.tempdir.name) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + if self.sqlCtx is None: + return # no hive available, skipped + + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, + "org.apache.spark.sql.json") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", + source="org.apache.spark.sql.json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.select("value").collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 1e597d64e03f..944fa414b0c0 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -31,7 +31,7 @@ class StatCounter(object): def __init__(self, values=[]): - self.n = 0L # Running count of our values + self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") @@ -87,7 +87,7 @@ def copy(self): return copy.deepcopy(self) def count(self): - return self.n + return int(self.n) def mean(self): return self.mu diff --git a/python/pyspark/status.py b/python/pyspark/status.py new file mode 100644 index 000000000000..a6fa7dd3144d --- /dev/null +++ b/python/pyspark/status.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple + +__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"] + + +class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")): + """ + Exposes information about Spark Jobs. + """ + + +class SparkStageInfo(namedtuple("SparkStageInfo", + "stageId currentAttemptId name numTasks numActiveTasks " + "numCompletedTasks numFailedTasks")): + """ + Exposes information about Spark Stages. + """ + + +class StatusTracker(object): + """ + Low-level status reporting APIs for monitoring job and stage progress. + + These APIs intentionally provide very weak consistency semantics; + consumers of these APIs should be prepared to handle empty / missing + information. For example, a job's stage ids may be known but the status + API may not have any information about the details of those stages, so + `getStageInfo` could potentially return `None` for a valid stage id. + + To limit memory usage, these APIs only provide information on recent + jobs / stages. These APIs will provide information for the last + `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs. + """ + def __init__(self, jtracker): + self._jtracker = jtracker + + def getJobIdsForGroup(self, jobGroup=None): + """ + Return a list of all known jobs in a particular job group. If + `jobGroup` is None, then returns all known jobs that are not + associated with a job group. + + The returned list may contain running, failed, and completed jobs, + and may vary across invocations of this method. This method does + not guarantee the order of the elements in its result. + """ + return list(self._jtracker.getJobIdsForGroup(jobGroup)) + + def getActiveStageIds(self): + """ + Returns an array containing the ids of all active stages. + """ + return sorted(list(self._jtracker.getActiveStageIds())) + + def getActiveJobsIds(self): + """ + Returns an array containing the ids of all active jobs. + """ + return sorted((list(self._jtracker.getActiveJobIds()))) + + def getJobInfo(self, jobId): + """ + Returns a :class:`SparkJobInfo` object, or None if the job info + could not be found or was garbage collected. + """ + job = self._jtracker.getJobInfo(jobId) + if job is not None: + return SparkJobInfo(jobId, job.stageIds(), str(job.status())) + + def getStageInfo(self, stageId): + """ + Returns a :class:`SparkStageInfo` object, or None if the stage + info could not be found or was garbage collected. + """ + stage = self._jtracker.getStageInfo(stageId) + if stage is not None: + # TODO: fetch them in batch for better performance + attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]] + return SparkStageInfo(stageId, *attrs) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index d48f3598e33b..ac5ba69e8dbb 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from __future__ import print_function + import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf -from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream @@ -157,7 +159,7 @@ def getOrCreate(cls, checkpointPath, setupFunc): try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: - print >>sys.stderr, "failed to load StreamingContext from checkpoint" + print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise jsc = jssc.sparkContext() @@ -189,7 +191,16 @@ def awaitTermination(self, timeout=None): if timeout is None: self._jssc.awaitTermination() else: - self._jssc.awaitTermination(int(timeout * 1000)) + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + + def awaitTerminationOrTimeout(self, timeout): + """ + Wait for the execution to stop. Return `true` if it's stopped; or + throw the reported error during the execution; or `false` if the + waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds + """ + self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -251,6 +262,20 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + def binaryRecordsStream(self, directory, recordLength): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as flat binary files with records of + fixed length. Files must be written to the monitored directory by "moving" + them from another location within the same file system. + File names starting with . are ignored. + + @param directory: Directory to load data from + @param recordLength: Length of each record in bytes + """ + return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self, + NoOpSerializer()) + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: @@ -279,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -296,8 +319,7 @@ def transform(self, dstreams, transformFunc): the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -320,6 +342,5 @@ def union(self, *dstreams): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2fe39392ff08..ff097985fae3 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -15,11 +15,15 @@ # limitations under the License. # -from itertools import chain, ifilter, imap +import sys import operator import time +from itertools import chain from datetime import datetime +if sys.version < "3": + from itertools import imap as map, ifilter as filter + from py4j.protocol import Py4JJavaError from pyspark import RDD @@ -76,7 +80,7 @@ def filter(self, f): Return a new DStream containing only the elements that satisfy predicate. """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def flatMap(self, f, preservesPartitioning=False): @@ -85,7 +89,7 @@ def flatMap(self, f, preservesPartitioning=False): this DStream, and then flattening the results """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def map(self, f, preservesPartitioning=False): @@ -93,7 +97,7 @@ def map(self, f, preservesPartitioning=False): Return a new DStream by applying a function to each element of DStream. """ def func(iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitions(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -150,7 +154,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) @@ -165,14 +169,14 @@ def pprint(self, num=10): """ def takeAndPrint(time, rdd): taken = rdd.take(num + 1) - print "-------------------------------------------" - print "Time: %s" % time - print "-------------------------------------------" + print("-------------------------------------------") + print("Time: %s" % time) + print("-------------------------------------------") for record in taken[:num]: - print record + print(record) if len(taken) > num: - print "..." - print + print("...") + print() self.foreachRDD(takeAndPrint) @@ -181,7 +185,7 @@ def mapValues(self, f): Return a new DStream by applying a map function to the value of each key-value pairs in this DStream without changing the key. """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): @@ -189,7 +193,7 @@ def flatMapValues(self, f): Return a new DStream by applying a flatmap function to the value of each key-value pairs in this DStream without changing the key. """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def glom(self): @@ -286,10 +290,10 @@ def transform(self, func): `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: oldfunc = func func = lambda t, rdd: oldfunc(rdd) - assert func.func_code.co_argcount == 2, "func should take one or two arguments" + assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) def transformWith(self, func, other, keepSerializer=False): @@ -300,10 +304,10 @@ def transformWith(self, func, other, keepSerializer=False): `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three arguments of (`time`, `rdd_a`, `rdd_b`) """ - if func.func_code.co_argcount == 2: + if func.__code__.co_argcount == 2: oldfunc = func func = lambda t, a, b: oldfunc(a, b) - assert func.func_code.co_argcount == 3, "func should take two or three arguments" + assert func.__code__.co_argcount == 3, "func should take two or three arguments" jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) @@ -460,7 +464,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) - return reduced.map(lambda (k, v): v) + return reduced.map(lambda kv: kv[1]) def countByWindow(self, windowDuration, slideDuration): """ @@ -489,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda (k, v): v > 0).count() + return counted.filter(lambda kv: kv[1] > 0).count() def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ @@ -548,7 +552,8 @@ def reduceFunc(t, a, b): def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: @@ -578,10 +583,10 @@ def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: - g = a.cogroup(b, numPartitions) - g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) - state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) - return state.filter(lambda (k, v): v is not None) + g = a.cogroup(b.partitionBy(numPartitions), numPartitions) + g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None)) + state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1])) + return state.filter(lambda k_v: k_v[1] is not None) jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py new file mode 100644 index 000000000000..8d610d6569b4 --- /dev/null +++ b/python/pyspark/streaming/kafka.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.streaming import DStream + +__all__ = ['KafkaUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class KafkaUtils(object): + + @staticmethod + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kafka Broker. + + :param ssc: StreamingContext object + :param zkQuorum: Zookeeper quorum (hostname:port,hostname:port,..). + :param groupId: The group id for this consumer. + :param topics: Dict of (topic_name -> numPartitions) to consume. + Each partition is consumed in its own thread. + :param kafkaParams: Additional params for Kafka + :param storageLevel: RDD storage level. + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + kafkaParams.update({ + "zookeeper.connect": zkQuorum, + "group.id": groupId, + "zookeeper.connection.timeout.ms": "10000", + }) + if not isinstance(topics, dict): + raise TypeError("topics should be dict") + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) + except Py4JJavaError as e: + # TODO: use --jar once it also work on driver + if 'ClassNotFoundException' in str(e.java_exception): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kafka libraries not found in class path. Try one of the following. + + 1. Include the Kafka library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (ssc.sparkContext.version, ssc.sparkContext.version)) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a8d876d0fa3b..5fa1e5ef081a 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -16,38 +16,56 @@ # import os +import sys from itertools import chain import time import operator -import unittest import tempfile +import struct +from functools import reduce + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext +from pyspark.streaming.kafka import KafkaUtils class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds - duration = 1 + timeout = 4 # seconds + duration = .2 - def setUp(self): - class_name = self.__class__.__name__ + @classmethod + def setUpClass(cls): + class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) - self.sc = SparkContext(appName=class_name, conf=conf) - self.sc.setCheckpointDir("/tmp") - # TODO: decrease duration to speed up tests + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir("/tmp") + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop() + self.ssc.stop(False) def wait_for(self, result, n): start_time = time.time() while len(result) < n and time.time() - start_time < self.timeout: time.sleep(0.01) if len(result) < n: - print "timeout after", self.timeout + print("timeout after", self.timeout) def _take(self, dstream, n): """ @@ -127,7 +145,7 @@ def test_map(self): def func(dstream): return dstream.map(str) - expected = map(lambda x: map(str, x), input) + expected = [list(map(str, x)) for x in input] self._test_func(input, func, expected) def test_flatMap(self): @@ -136,8 +154,8 @@ def test_flatMap(self): def func(dstream): return dstream.flatMap(lambda x: (x, x * 2)) - expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), - input) + expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x)))) + for x in input] self._test_func(input, func, expected) def test_filter(self): @@ -146,7 +164,7 @@ def test_filter(self): def func(dstream): return dstream.filter(lambda x: x % 2 == 0) - expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + expected = [[y for y in x if y % 2 == 0] for x in input] self._test_func(input, func, expected) def test_count(self): @@ -155,7 +173,7 @@ def test_count(self): def func(dstream): return dstream.count() - expected = map(lambda x: [len(x)], input) + expected = [[len(x)] for x in input] self._test_func(input, func, expected) def test_reduce(self): @@ -164,7 +182,7 @@ def test_reduce(self): def func(dstream): return dstream.reduce(operator.add) - expected = map(lambda x: [reduce(operator.add, x)], input) + expected = [[reduce(operator.add, x)] for x in input] self._test_func(input, func, expected) def test_reduceByKey(self): @@ -181,27 +199,27 @@ def func(dstream): def test_mapValues(self): """Basic operation test for DStream.mapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 2), (3, 3)], + [(0, 4), (1, 1), (2, 2), (3, 3)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.mapValues(lambda x: x + 10) expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], - [("", 14), (1, 11), (2, 12), (3, 13)], + [(0, 14), (1, 11), (2, 12), (3, 13)], [(1, 11), (2, 11), (3, 11), (4, 11)]] self._test_func(input, func, expected, sort=True) def test_flatMapValues(self): """Basic operation test for DStream.flatMapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 1), (3, 1)], + [(0, 4), (1, 1), (2, 1), (3, 1)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.flatMapValues(lambda x: (x, x + 10)) expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), ("c", 1), ("c", 11), ("d", 1), ("d", 11)], - [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] self._test_func(input, func, expected) @@ -229,7 +247,7 @@ def f(iterator): def test_countByValue(self): """Basic operation test for DStream.countByValue.""" - input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]] def func(dstream): return dstream.countByValue() @@ -281,7 +299,7 @@ def test_union(self): def func(d1, d2): return d1.union(d2) - expected = [range(6), range(6), range(6)] + expected = [list(range(6)), list(range(6)), list(range(6))] self._test_func(input1, func, expected, input2=input2) def test_cogroup(self): @@ -360,13 +378,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 20 + timeout = 5 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(3, 1).count() + return dstream.window(.6, .2).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -375,7 +393,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(3, 1) + return dstream.countByWindow(.6, .2) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -384,7 +402,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(5, 1) + return dstream.countByWindow(1, .2) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -393,7 +411,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(5, 1) + return dstream.countByValueAndWindow(1, .2) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -402,7 +420,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -420,7 +438,7 @@ class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 def _add_input_stream(self): - inputs = map(lambda x: range(1, x), range(101)) + inputs = [range(1, x) for x in range(101)] stream = self.ssc.queueStream(inputs) self._collect(stream, 1, block=False) @@ -433,11 +451,11 @@ def test_stop_only_streaming_context(self): def test_stop_multiple_times(self): self._add_input_stream() self.ssc.start() - self.ssc.stop() - self.ssc.stop() + self.ssc.stop(False) + self.ssc.stop(False) def test_queue_stream(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) result = self._collect(dstream, 3) self.assertEqual(input, result) @@ -453,10 +471,24 @@ def test_text_file_stream(self): with open(os.path.join(d, name), "w") as f: f.writelines(["%d\n" % i for i in range(10)]) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], result) + self.assertEqual([list(range(10)), list(range(10))], result) + + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", bytes(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) def test_union(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) @@ -478,10 +510,7 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) -class CheckpointTests(PySparkStreamingTestCase): - - def setUp(self): - pass +class CheckpointTests(unittest.TestCase): def test_get_or_create(self): inputd = tempfile.mkdtemp() @@ -501,12 +530,12 @@ def setup(): return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() def check_output(n): while not os.listdir(outputd): - time.sleep(0.1) + time.sleep(0.01) time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) @@ -536,10 +565,49 @@ def check_output(n): ssc.stop(True, True) time.sleep(1) - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) + ssc.stop(True, True) + + +class KafkaStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(KafkaStreamTests, self).setUp() + + kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") + self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils.setup() + + def tearDown(self): + if self._kafkaTestUtils is not None: + self._kafkaTestUtils.teardown() + self._kafkaTestUtils = None + + super(KafkaStreamTests, self).tearDown() + + def test_kafka_stream(self): + """Test the Python Kafka stream API.""" + topic = "topic1" + sendData = {"a": 3, "b": 5, "c": 10} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), + "test-streaming-consumer", {topic: 1}, + {"auto.offset.reset": "smallest"}) + + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + self.assertEqual(sendData, result) if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 86ee5aa04f25..34291f30a565 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -91,9 +91,9 @@ def dumps(self, id): except Exception: traceback.print_exc() - def loads(self, bytes): + def loads(self, data): try: - f, deserializers = self.serializer.loads(str(bytes)) + f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() @@ -116,7 +116,7 @@ def rddToFileName(prefix, suffix, timestamp): """ if isinstance(timestamp, datetime): seconds = time.mktime(timestamp.timetuple()) - timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + timestamp = int(seconds * 1000) + timestamp.microsecond // 1000 if suffix is None: return prefix + "-" + str(timestamp) else: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7..ea63a396da5b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,8 +19,8 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ + from array import array -from fileinput import input from glob import glob import os import re @@ -34,6 +34,8 @@ import threading import hashlib +from py4j.protocol import Py4JJavaError + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -42,17 +44,27 @@ sys.exit(1) else: import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO from pyspark.conf import SparkConf from pyspark.context import SparkContext +from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ + PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ + FlattenedValuesSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -76,9 +88,9 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] - self.data = zip(self.l, self.l) + self.data = list(zip(self.l, self.l)) self.agg = Aggregator(lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) @@ -86,84 +98,110 @@ def setUp(self): def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 20) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10, partitions=3) - m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.items()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1024) - self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1) - self.assertEquals(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") sc = SparkContext(conf=conf) - l = range(10240) + l = list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 10) - self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect()) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -171,11 +209,11 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from cPickle import dumps, loads + from pickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) - self.assertEquals(p1, p2) + self.assertEqual(p1, p2) def test_itemgetter(self): from operator import itemgetter @@ -217,7 +255,7 @@ def test_pickling_file_handles(self): ser = CloudPickleSerializer() out1 = sys.stderr out2 = ser.loads(ser.dumps(out1)) - self.assertEquals(out1, out2) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -234,19 +272,48 @@ def __reduce__(self): def foo(): sys.exit(0) - self.assertTrue("exit" in foo.func_code.co_names) + self.assertTrue("exit" in foo.__code__.co_names) ser.dumps(foo) def test_compressed_serializer(self): ser = CompressedSerializer(PickleSerializer()) - from StringIO import StringIO + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO io = StringIO() ser.dump_stream(["abc", u"123", range(5)], io) io.seek(0) self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) ser.dump_stream(range(1000), io) io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) + + +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) class PySparkTestCase(unittest.TestCase): @@ -311,7 +378,7 @@ def test_checkpoint_and_restore(self): self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), flatMappedRDD._jrdd_deserializer) - self.assertEquals([1, 2, 3, 4], recovered.collect()) + self.assertEqual([1, 2, 3, 4], recovered.collect()) class AddFileTests(PySparkTestCase): @@ -320,16 +387,11 @@ def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) - def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, - self.sc.parallelize(range(2)).map(func).first) - log4j.LogManager.getRootLogger().setLevel(old_level) + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") @@ -343,7 +405,7 @@ def test_add_file_locally(self): download_path = SparkFiles.get("hello.txt") self.assertNotEqual(path, download_path) with open(download_path) as test_file: - self.assertEquals("Hello World!\n", test_file.readline()) + self.assertEqual("Hello World!\n", test_file.readline()) def test_add_py_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -352,7 +414,7 @@ def func(): from userlibrary import UserClass self.assertRaises(ImportError, func) path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addFile(path) + self.sc.addPyFile(path) from userlibrary import UserClass self.assertEqual("Hello World!", UserClass().hello()) @@ -362,7 +424,7 @@ def test_add_egg_file_locally(self): def func(): from userlib import UserClass self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg") + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") self.sc.addPyFile(path) from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) @@ -398,8 +460,9 @@ def test_save_as_textfile_with_unicode(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) def test_save_as_textfile_with_utf8(self): x = u"\u00A1Hola, mundo!" @@ -407,19 +470,20 @@ def test_save_as_textfile_with_utf8(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) rdd2 = self.sc.parallelize([3, 4]) cart = rdd1.cartesian(rdd2) - result = cart.map(lambda (x, y): x + y).collect() + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() def test_transforming_pickle_file(self): # Regression test for SPARK-2601 - data = self.sc.parallelize(["Hello", "World!"]) + data = self.sc.parallelize([u"Hello", u"World!"]) tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsPickleFile(tempFile.name) @@ -432,19 +496,20 @@ def test_cartesian_on_textfile(self): a = self.sc.textFile(path) result = a.cartesian(a).collect() (x, y) = result[0] - self.assertEqual("Hello World!", x.strip()) - self.assertEqual("Hello World!", y.strip()) + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name) filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) @@ -481,21 +546,21 @@ def test_namedtuple_in_rdd(self): jon = Person(1, "Jon", "Doe") jane = Person(2, "Jane", "Doe") theDoes = self.sc.parallelize([jon, jane]) - self.assertEquals([jon, jane], theDoes.collect()) + self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): - N = 100000 + N = 10000 data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 270MB + bdata = self.sc.broadcast(data) # 27MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEquals(N, m) + self.assertEqual(N, m) def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = range(1 << 15) + r = list(range(1 << 15)) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -506,7 +571,7 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -517,14 +582,12 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) def test_large_closure(self): - N = 1000000 + N = 200000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) @@ -543,29 +606,36 @@ def test_zip_with_different_serializers(self): # regression test for bug in _reserializer() self.assertEqual(cnt, t.zip(rdd).count()) + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + def test_zip_with_different_number_of_items(self): a = self.sc.parallelize(range(5), 2) # different number of partitions b = self.sc.parallelize(range(100, 106), 3) self.assertRaises(ValueError, lambda: a.zip(b)) - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEquals(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): rdd = self.sc.parallelize(range(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) self.assertTrue(18 < rdd.countApproxDistinct() < 22) @@ -579,59 +649,59 @@ def test_count_approx_distinct(self): def test_histogram(self): # empty rdd = self.sc.parallelize([]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) self.assertRaises(ValueError, lambda: rdd.histogram(1)) # out of range rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) # in range with one bucket rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals([4], rdd.histogram([0, 10])[1]) - self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) # in range with one bucket exact match - self.assertEquals([4], rdd.histogram([1, 4])[1]) + self.assertEqual([4], rdd.histogram([1, 4])[1]) # out of range with two buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) # out of range with two uneven buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) # in range with two buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two bucket and None rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two uneven buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) # mixed range with two uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) # mixed range with four uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # mixed range with uneven buckets and NaN rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # out of range with infinite buckets rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) # invalid buckets self.assertRaises(ValueError, lambda: rdd.histogram([])) @@ -641,25 +711,25 @@ def test_histogram(self): # without buckets rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 4], [4]), rdd.histogram(1)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) # without buckets single element rdd = self.sc.parallelize([1]) - self.assertEquals(([1, 1], [1]), rdd.histogram(1)) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) # without bucket no range rdd = self.sc.parallelize([1] * 4) - self.assertEquals(([1, 1], [4]), rdd.histogram(1)) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) # without buckets basic two rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) # without buckets with more requested than elements rdd = self.sc.parallelize([1, 2]) buckets = [1 + 0.2 * i for i in range(6)] hist = [1, 0, 0, 0, 1] - self.assertEquals((buckets, hist), rdd.histogram(5)) + self.assertEqual((buckets, hist), rdd.histogram(5)) # invalid RDDs rdd = self.sc.parallelize([1, float('inf')]) @@ -669,15 +739,8 @@ def test_histogram(self): # string rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - # mixed RDD - rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2) - self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1]) - self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) def test_repartitionAndSortWithinPartitions(self): @@ -685,16 +748,31 @@ def test_repartitionAndSortWithinPartitions(self): repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) partitions = repartitioned.glom().collect() - self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEquals(rdd.getNumPartitions(), 10) - self.assertEquals(rdd.distinct().count(), 3) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) result = rdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -714,6 +792,84 @@ def test_sample(self): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + tracker = self.sc.statusTracker() + + self.sc.setJobGroup("test1", "test", True) + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + class ProfilerTests(PySparkTestCase): @@ -724,255 +880,51 @@ def setUp(self): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) func_names = [func_name for fname, n, func_name in stat_list] self.assertTrue("heavy_foo" in func_names) + old_stdout = sys.stdout + sys.stdout = io = StringIO() self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + d = tempfile.gettempdir() self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - + self.sc.profiler_collector.profiler_cls = TestCustomProfiler -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ + self.do_computation() - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ - other.x == self.x and other.y == self.y - - -class SQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) - def setUp(self): - self.sqlCtx = SQLContext(self.sc) - - def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() - - # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 18): + x = 1 - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) - - def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = srdd.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = srdd.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = srdd.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) - - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_infer_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) - - def test_parquet_with_udt(self): - from pyspark.tests import ExamplePoint - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) class InputFormatTests(ReusedPySparkTestCase): @@ -989,6 +941,7 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", @@ -1037,15 +990,16 @@ def test_sequencefiles(self): en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] self.assertEqual(nulls, en) - maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] - self.assertEqual(maps, em) + for v in maps: + self.assertTrue(v in em) # arrays get pickled to tuples by default tuples = sorted(self.sc.sequenceFile( @@ -1172,8 +1126,8 @@ def test_converters(self): def test_binary_files(self): path = os.path.join(self.tempdir.name, "binaryfiles") os.mkdir(path) - data = "short binary data" - with open(os.path.join(path, "part-0000"), 'w') as f: + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: f.write(data) [(p, d)] = self.sc.binaryFiles(path).collect() self.assertTrue(p.endswith("part-0000")) @@ -1186,7 +1140,7 @@ def test_binary_records(self): for i in range(100): f.write('%04d' % i) result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(range(100), result) + self.assertEqual(list(range(100)), result) class OutputFormatTests(ReusedPySparkTestCase): @@ -1198,6 +1152,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir.name, ignore_errors=True) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] @@ -1238,8 +1193,9 @@ def test_sequencefiles(self): (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) - self.assertEqual(maps, em) + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1251,12 +1207,13 @@ def test_oldhadoop(self): "org.apache.hadoop.mapred.SequenceFileOutputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable") - result = sorted(self.sc.hadoopFile( + result = self.sc.hadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) - self.assertEqual(result, dict_data) + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -1266,12 +1223,13 @@ def test_oldhadoop(self): } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) input_conf = {"mapred.input.dir": basepath + "/olddataset/"} - old_dataset = sorted(self.sc.hadoopRDD( + result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect()) - self.assertEqual(old_dataset, dict_data) + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) def test_newhadoop(self): basepath = self.tempdir.name @@ -1306,6 +1264,7 @@ def test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) + @unittest.skipIf(sys.version >= "3", "serialize of array") def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -1386,7 +1345,7 @@ def test_reserialization(self): basepath = self.tempdir.name x = range(1, 5) y = range(1001, 1005) - data = zip(x, y) + data = list(zip(x, y)) rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) @@ -1437,7 +1396,7 @@ def connect(self, port): sock = socket(AF_INET, SOCK_STREAM) sock.connect(('127.0.0.1', port)) # send a split index of -1 to shutdown the worker - sock.send("\xFF\xFF\xFF\xFF") + sock.send(b"\xFF\xFF\xFF\xFF") sock.close() return True @@ -1477,8 +1436,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class WorkerTests(PySparkTestCase): - +class WorkerTests(ReusedPySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1493,7 +1451,10 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1)).foreach(sleep) + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass import threading t = threading.Thread(target=run) t.daemon = True @@ -1502,7 +1463,8 @@ def run(): daemon_pid, worker_pid = 0, 0 while True: if os.path.exists(path): - data = open(path).read().split(' ') + with open(path) as f: + data = f.read().split(' ') daemon_pid, worker_pid = map(int, data) break time.sleep(0.1) @@ -1533,18 +1495,20 @@ def test_after_exception(self): def raise_exception(_): raise Exception() rdd = self.sc.parallelize(range(100), 1) - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) def test_after_jvm_exception(self): tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name, 1) filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) @@ -1577,6 +1541,17 @@ def count(): self.assertTrue(not t.isAlive()) self.assertEqual(100000, rdd.count()) + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = sys.version_info + sys.version_info = (2, 0, 0) + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + sys.version_info = version + class SparkSubmitTests(unittest.TestCase): @@ -1587,43 +1562,71 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.programDir) - def createTempFile(self, name, content): + def createTempFile(self, name, content, dir=None): """ Create a temp file with the given name and content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) with open(path, "w") as f: f.write(content) return path - def createFileInZip(self, name, content): + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): """ Create a zip archive containing a file with the given content and return its path. Strips leading spaces from content up to the first '|' in each line. """ pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) - path = os.path.join(self.programDir, name + ".zip") + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) zip = zipfile.ZipFile(path, 'w') zip.writestr(name, content) zip.close() return path + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + def test_single_script(self): """Submit and test a single script file""" script = self.createTempFile("test.py", """ |from pyspark import SparkContext | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) def test_script_with_local_functions(self): """Submit and test a single script file calling a global function""" @@ -1634,12 +1637,12 @@ def test_script_with_local_functions(self): | return x * 3 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) def test_module_dependency(self): """Submit and test a script with a dependency on another module""" @@ -1648,7 +1651,7 @@ def test_module_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1658,7 +1661,7 @@ def test_module_dependency(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_module_dependency_on_cluster(self): """Submit and test a script with a dependency on another module on a cluster""" @@ -1667,7 +1670,7 @@ def test_module_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1678,7 +1681,40 @@ def test_module_dependency_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", + "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" @@ -1689,7 +1725,7 @@ def test_single_script_on_cluster(self): | return x * 2 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf @@ -1698,7 +1734,7 @@ def test_single_script_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) class ContextTests(unittest.TestCase): @@ -1733,6 +1769,42 @@ def test_with_stop(self): sc.stop() self.assertEqual(SparkContext._active_spark_context, None) + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -1742,7 +1814,7 @@ class SciPyTests(PySparkTestCase): def test_serialize(self): from scipy.special import gammaln x = range(1, 5) - expected = map(gammaln, x) + expected = list(map(gammaln, x)) observed = self.sc.parallelize(x).map(gammaln).collect() self.assertEqual(expected, observed) @@ -1763,11 +1835,11 @@ def test_statcounter_array(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: - print "NOTE: Skipping NumPy tests as it does not seem to be installed" + print("NOTE: Skipping NumPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: - print "NOTE: NumPy tests were skipped as it does not seem to be installed" + print("NOTE: NumPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7e5343c973dc..fbdaf3a5814c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,13 +18,12 @@ """ Worker that receives input from Piped RDD. """ +from __future__ import print_function import os import sys import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -39,9 +38,9 @@ def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) - write_long(1000 * boot, outfile) - write_long(1000 * init, outfile) - write_long(1000 * finish, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) def add_path(path): @@ -74,6 +73,9 @@ def main(infile, outfile): for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) add_path(os.path.join(spark_files_dir, filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) @@ -90,32 +92,32 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, profiler, deserializer, serializer), version = command + if version != sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + (sys.version_info[:2], version)) init_time = time.time() def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) + if profiler: + profiler.profile(process) else: process() except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) + write_with_length(traceback.format_exc().encode("utf-8"), outfile) except IOError: # JVM close the socket pass except Exception: # Write the error to stderr if it happened while serializing - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() + print("PySpark worker failed with exception:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) diff --git a/python/run-tests b/python/run-tests index 9ee19ed6e6b2..88b63b84fdc2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -21,11 +21,14 @@ # Figure out where the Spark framework is installed FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +. "$FWDIR"/bin/load-spark-env.sh + # CD into the python directory to find things on the right path cd "$FWDIR/python" FAILED=0 LOG_FILE=unit-tests.log +START=$(date +"%s") rm -f $LOG_FILE @@ -33,9 +36,9 @@ rm -f $LOG_FILE rm -rf metastore warehouse function run_test() { - echo "Running test: $1" | tee -a $LOG_FILE - - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE + echo -en "Running test: $1 ... " | tee -a $LOG_FILE + start=$(date +"%s") + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -46,6 +49,9 @@ function run_test() { echo "Had test failures; see logs." echo -en "\033[0m" # No color exit -1 + else + now=$(date +"%s") + echo "ok ($(($now - $start))s)" fi } @@ -57,32 +63,61 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql.py" + run_test "pyspark/sql/_types.py" + run_test "pyspark/sql/context.py" + run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/functions.py" + run_test "pyspark/sql/tests.py" } function run_mllib_tests() { echo "Run mllib tests ..." run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/stat/_statistics.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" run_test "pyspark/mllib/tests.py" } +function run_ml_tests() { + echo "Run ml tests ..." + run_test "pyspark/ml/feature.py" + run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tests.py" +} + function run_streaming_tests() { echo "Run streaming tests ..." + + KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly + JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" + for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 + echo "You need to build Spark with " \ + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ + "'build/mvn package' before running this program" 1>&2 + exit 1 + fi + KAFKA_ASSEMBLY_JAR="$f" + done + + export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } @@ -102,8 +137,22 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_ml_tests run_streaming_tests +# Try to test with Python 3 +if [ $(which python3.4) ]; then + export PYSPARK_PYTHON="python3.4" + echo "Testing with Python3.4 version:" + $PYSPARK_PYTHON --version + + run_core_tests + run_sql_tests + run_mllib_tests + run_ml_tests + run_streaming_tests +fi + # Try to test with PyPy if [ $(which pypy) ]; then export PYSPARK_PYTHON="pypy" @@ -116,9 +165,8 @@ if [ $(which pypy) ]; then fi if [[ $FAILED == 0 ]]; then - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color + now=$(date +"%s") + echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" fi # TODO: in the long-run, it would be nice to use a test runner like `nose`. diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg deleted file mode 100644 index 1674c9cb2227..000000000000 Binary files a/python/test_support/userlib-0.1-py2.7.egg and /dev/null differ diff --git a/python/test_support/userlib-0.1.zip b/python/test_support/userlib-0.1.zip new file mode 100644 index 000000000000..496e1349aa96 Binary files /dev/null and b/python/test_support/userlib-0.1.zip differ diff --git a/repl/pom.xml b/repl/pom.xml index ae7c31aef4f5..03053b4c3b28 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -33,8 +33,6 @@ repl - /usr/share/spark - root scala-2.10/src/main/scala scala-2.10/src/test/scala @@ -66,7 +64,6 @@ org.apache.spark spark-sql_${scala.binary.version} ${project.version} - test org.scala-lang @@ -87,6 +84,35 @@ scalacheck_${scala.binary.version} test + + org.mockito + mockito-all + test + + + + + org.eclipse.jetty + jetty-server + + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-http + + + + + org.scala-lang + scala-library + target/scala-${scala.binary.version}/classes diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 72c1a989999b..8dc0e0c96592 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -45,6 +45,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** The Scala interactive shell. It provides a read-eval-print loop @@ -130,6 +131,7 @@ class SparkILoop( // NOTE: Must be public for visibility @DeveloperApi var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ override def echoCommandMessage(msg: String) { intp.reporter printMessage msg @@ -1016,6 +1018,23 @@ class SparkILoop( sparkContext } + @DeveloperApi + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case cnf: java.lang.ClassNotFoundException => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster(): String = { val master = this.master match { case Some(m) => m @@ -1045,15 +1064,16 @@ class SparkILoop( private def main(settings: Settings): Unit = process(settings) } -object SparkILoop { +object SparkILoop extends Logging { implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp private def echo(msg: String) = Console println msg def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") - val propJars = sys.props.get("spark.jars").flatMap { p => - if (p == "") None else Some(p) + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") } + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 99bd777c04fd..05faef8786d2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -127,7 +127,17 @@ private[repl] trait SparkILoopInit { _sc } """) + command(""" + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 91c9c52c3c98..934daaeaafca 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -121,9 +121,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |var v = 7 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -137,7 +137,7 @@ class ReplSuite extends FunSuite { |class C { |def foo = 5 |} - |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -148,7 +148,7 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |def double(x: Int) = x + x - |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -160,9 +160,9 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -178,9 +178,9 @@ class ReplSuite extends FunSuite { """ |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -216,14 +216,14 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -255,14 +255,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + test("SPARK-2576 importing SQLContext.implicits._") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createSchemaRDD + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -275,26 +275,26 @@ class ReplSuite extends FunSuite { |val t = new TestClass |import t.testMethod |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -309,10 +309,22 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local[2]", """ |case class Foo(i: Int) - |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } + + test("collecting objects of class defined in repl - shuffling") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |case class Foo(i: Int) + |val list = List((1, Foo(1)), (1, Foo(2))) + |val ret = sc.parallelize(list).groupByKey().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) + } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 69e44d4f916e..2210fbaafead 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -19,6 +19,7 @@ package org.apache.spark.repl import org.apache.spark.util.Utils import org.apache.spark._ +import org.apache.spark.sql.SQLContext import scala.tools.nsc.Settings import scala.tools.nsc.interpreter.SparkILoop @@ -34,6 +35,7 @@ object Main extends Logging { "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ + var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. def main(args: Array[String]) { @@ -49,6 +51,9 @@ object Main extends Logging { def getAddedJars: Array[String] = { val envJars = sys.env.get("ADD_JARS") + if (envJars.isDefined) { + logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") + } val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } val jars = propJars.orElse(envJars).getOrElse("") Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) @@ -74,6 +79,22 @@ object Main extends Logging { sparkContext } + def createSQLContext(): SQLContext = { + val name = "org.apache.spark.sql.hive.HiveContext" + val loader = Utils.getContextOrSparkClassLoader + try { + sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) + .newInstance(sparkContext).asInstanceOf[SQLContext] + logInfo("Created sql context (with Hive support)..") + } + catch { + case cnf: java.lang.ClassNotFoundException => + sqlContext = new SQLContext(sparkContext) + logInfo("Created sql context..") + } + sqlContext + } + private def getMaster: String = { val master = { val envMaster = sys.env.get("MASTER") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 250727305970..7a5e94da5cbf 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -66,8 +66,18 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) println("Spark context available as sc.") _sc } - """) + """) + command( """ + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } + """) command("import org.apache.spark.SparkContext._") + command("import sqlContext.implicits._") + command("import sqlContext.sql") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f966f25c5a14..14f5e9ed4f25 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,11 +21,9 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkContext @@ -128,9 +126,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |var v = 7 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -144,7 +142,7 @@ class ReplSuite extends FunSuite { |class C { |def foo = 5 |} - |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -155,7 +153,7 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ |def double(x: Int) = x + x - |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -167,9 +165,9 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -185,9 +183,9 @@ class ReplSuite extends FunSuite { """ |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -196,8 +194,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") @@ -224,14 +221,14 @@ class ReplSuite extends FunSuite { """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -263,14 +260,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + test("SPARK-2576 importing SQLContext.createDataFrame.") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createSchemaRDD + |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -283,26 +280,26 @@ class ReplSuite extends FunSuite { |val t = new TestClass |import t.testMethod |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) { test("running on Mesos") { val output = runInterpreter("localquiet", """ |var v = 7 |def getV() = v - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |v = 10 - |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_) |var array = new Array[Int](5) |val broadcastArray = sc.broadcast(array) - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() |array(0) = 5 - |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -317,10 +314,22 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local[2]", """ |case class Foo(i: Int) - |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } + + test("collecting objects of class defined in repl - shuffling") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |case class Foo(i: Int) + |val list = List((1, Foo(1)), (1, Foo(2))) + |val ret = sc.parallelize(list).groupByKey().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) + } } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 5ee325008a5c..004941d5f50a 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,13 +17,14 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, InputStream} -import java.net.{URI, URL, URLEncoder} -import java.util.concurrent.{Executors, ExecutorService} +import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} + +import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils import org.apache.spark.util.ParentClassLoader @@ -37,15 +38,18 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ * Allows the user to specify if user class path should be first */ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, - userClassPathFirst: Boolean) extends ClassLoader { + userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath val parentLoader = new ParentClassLoader(parent) + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { - if (uri.getScheme() == "http") { + if (Set("http", "https", "ftp").contains(uri.getScheme)) { null } else { FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) @@ -71,27 +75,82 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), + SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + connection.getErrorStream.close() + } catch { + case ioe: IOException => + logError("Exception while closing error stream", ioe) + } + throw new ClassNotFoundException(s"Class file not found at URL $url") + } else { + connection.getInputStream + } + } catch { + case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => + connection.disconnect() + throw e + } + } + + private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + val path = new Path(directory, pathInDirectory) + if (fileSystem.exists(path)) { + fileSystem.open(path) + } else { + throw new ClassNotFoundException(s"Class file not found at path $path") + } + } + def findClassLocally(name: String): Option[Class[_]] = { + val pathInDirectory = name.replace('.', '/') + ".class" + var inputStream: InputStream = null try { - val pathInDirectory = name.replace('.', '/') + ".class" - val inputStream = { + inputStream = { if (fileSystem != null) { - fileSystem.open(new Path(directory, pathInDirectory)) + getClassFileInputStreamFromFileSystem(pathInDirectory) } else { - if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL().openStream() - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() - } + getClassFileInputStreamFromHttpServer(pathInDirectory) } } val bytes = readAndTransformClass(name, inputStream) - inputStream.close() Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: Exception => None + case e: ClassNotFoundException => + // We did not find the class + logDebug(s"Did not load class $name from REPL class server at $uri", e) + None + case e: Exception => + // Something bad happened while checking if the class exists + logError(s"Failed to check existence of class $name on REPL class server at $uri", e) + None + } finally { + if (inputStream != null) { + try { + inputStream.close() + } catch { + case e: Exception => + logError("Exception while closing inputStream", e) + } + } } } diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e7e4a4113174..e2ee9c963a4d 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 6a79e76a34db..c709cde74074 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -20,13 +20,25 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.concurrent.Interruptor +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, TestUtils} +import org.apache.spark._ import org.apache.spark.util.Utils -class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { +class ExecutorClassLoaderSuite + extends FunSuite + with BeforeAndAfterAll + with MockitoSugar + with Logging { val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") @@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { var tempDir2: File = _ var url1: String = _ var urls2: Array[URL] = _ + var classServer: HttpServer = _ override def beforeAll() { super.beforeAll() @@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { super.afterAll() + if (classServer != null) { + classServer.stop() + } Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) } test("child first") { @@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { } } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { + // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class + // from the driver's class server would leak a HTTP connection, causing the class server's + // thread / connection pool to be exhausted. + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + classServer = new HttpServer(conf, tempDir1, securityManager) + classServer.start() + // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this + val mockEnv = mock[SparkEnv] + when(mockEnv.securityManager).thenReturn(securityManager) + SparkEnv.set(mockEnv) + // Create an ExecutorClassLoader that's configured to load classes from the HTTP server + val parentLoader = new URLClassLoader(Array.empty, null) + val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + classLoader.httpUrlConnectionTimeoutMillis = 500 + // Check that this class loader can actually load classes that exist + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + // Try to perform a full GC now, since GC during the test might mask resource leaks + System.gc() + // When the original bug occurs, the test thread becomes blocked in a classloading call + // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to + // shut down the HTTP server when the test times out + val interruptor: Interruptor = new Interruptor { + override def apply(thread: Thread): Unit = { + classServer.stop() + classServer = null + thread.interrupt() + } + } + def tryAndFailToLoadABunchOfClasses(): Unit = { + // The number of trials here should be much larger than Jetty's thread / connection limit + // in order to expose thread or connection leaks + for (i <- 1 to 1000) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException() + } + // Incorporate the iteration number into the class name in order to avoid any response + // caching that might be added in the future + intercept[ClassNotFoundException] { + classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance() + } + } + } + failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) + } + } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 89608bc41b71..de762acc8fa0 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " # if no args specified, show usage if [ $# -le 1 ]; then @@ -121,60 +121,97 @@ if [ "$SPARK_NICENESS" = "" ]; then export SPARK_NICENESS=0 fi +run_command() { + mode="$1" + shift -case $option in + mkdir -p "$SPARK_PID_DIR" - (start|spark-submit) + if [ -f "$pid" ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo "$command running as process $TARGET_ID. Stop it first." + exit 1 + fi + fi - mkdir -p "$SPARK_PID_DIR" + if [ "$SPARK_MASTER" != "" ]; then + echo rsync from "$SPARK_MASTER" + rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' "$SPARK_MASTER/" "$SPARK_HOME" + fi - if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo $command running as process `cat $pid`. Stop it first. - exit 1 - fi - fi + spark_rotate_log "$log" + echo "starting $command, logging to $log" + + case "$mode" in + (class) + nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & + newpid="$!" + ;; + + (submit) + nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & + newpid="$!" + ;; + + (*) + echo "unknown mode: $mode" + exit 1 + ;; + esac + + echo "$newpid" > "$pid" + sleep 2 + # Check if the process has died; in that case we'll tail the log so the user can see + if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then + echo "failed to launch $command:" + tail -2 "$log" | sed 's/^/ /' + echo "full log in $log" + fi +} - if [ "$SPARK_MASTER" != "" ]; then - echo rsync from "$SPARK_MASTER" - rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' $SPARK_MASTER/ "$SPARK_HOME" - fi +case $option in - spark_rotate_log "$log" - echo starting $command, logging to $log - if [ $option == spark-submit ]; then - source "$SPARK_HOME"/bin/utils.sh - gatherSparkSubmitOpts "$@" - nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-submit --class $command \ - "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" >> "$log" 2>&1 < /dev/null & - else - nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & - fi - newpid=$! - echo $newpid > $pid - sleep 2 - # Check if the process has died; in that case we'll tail the log so the user can see - if ! kill -0 $newpid >/dev/null 2>&1; then - echo "failed to launch $command:" - tail -2 "$log" | sed 's/^/ /' - echo "full log in $log" - fi + (submit) + run_command submit "$@" + ;; + + (start) + run_command class "$@" ;; (stop) if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo stopping $command - kill `cat $pid` + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo "stopping $command" + kill "$TARGET_ID" && rm -f "$pid" else - echo no $command to stop + echo "no $command to stop" fi else - echo no $command to stop + echo "no $command to stop" fi ;; + (status) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo $command is running. + exit 0 + else + echo $pid file is present but $command not running + exit 1 + fi + else + echo $command not running. + exit 2 + fi + ;; + (*) echo $usage exit 1 diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 2fc35309f4ca..4c919ff76a8f 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -17,10 +17,69 @@ # limitations under the License. # -# Usage: start-slave.sh -# where is like "spark://localhost:7077" +# Starts a slave on the machine this script is executed on. +# +# Environment Variables +# +# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# slave. Default is 1. +# SPARK_WORKER_PORT The base port number for the first worker. If set, +# subsequent workers will increment this number. If +# unset, Spark will find a valid port number, but +# with no guarantee of a predictable pattern. +# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first +# worker. Subsequent workers will increment this +# number. Default is 8081. + +usage="Usage: start-slave.sh where is like spark://localhost:7077" + +if [ $# -lt 1 ]; then + echo $usage + echo Called as start-slave.sh $* + exit 1 +fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +# First argument should be the master; we need to store it aside because we may +# need to insert arguments between it and the other arguments +MASTER=$1 +shift + +# Determine desired worker port +if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then + SPARK_WORKER_WEBUI_PORT=8081 +fi + +# Start up the appropriate number of workers on this machine. +# quick local function to start a worker +function start_instance { + WORKER_NUM=$1 + shift + + if [ "$SPARK_WORKER_PORT" = "" ]; then + PORT_FLAG= + PORT_NUM= + else + PORT_FLAG="--port" + PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) + fi + WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) + + "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" +} + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + start_instance 1 "$@" +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + start_instance $(( 1 + $i )) "$@" + done +fi + diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index ba1a84abc1fe..24d6268815ed 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,6 +17,8 @@ # limitations under the License. # +# Starts a slave instance on each machine specified in the conf/slaves file. + sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" @@ -57,13 +59,4 @@ if [ "$START_TACHYON" == "true" ]; then fi # Launch the slaves -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" -else - if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then - SPARK_WORKER_WEBUI_PORT=8081 - fi - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 50e8e06418b0..5b0aeb177fff 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -26,6 +26,8 @@ set -o posix # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" function usage { @@ -50,4 +52,4 @@ fi export SUBMIT_USAGE_FUNCTION=usage -exec "$FWDIR"/sbin/spark-daemon.sh spark-submit $CLASS 1 "$@" +exec "$FWDIR"/sbin/spark-daemon.sh submit $CLASS 1 "$@" diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 971d5d49da66..1a9abe07db84 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -17,8 +17,8 @@ # limitations under the License. # -# Start all spark daemons. -# Run this on the master nde +# Stop all spark daemons. +# Run this on the master node. sbin="`dirname "$0"`" diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index b6bdaa4db373..729702d92191 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# Starts the master on the machine this script is executed on. +# Stops the master on the machine this script is executed on. sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh new file mode 100755 index 000000000000..3d1da5b254f2 --- /dev/null +++ b/sbin/stop-slave.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to stop all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: stop-slave.sh +# Stops all slaves on this worker machine + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 7c2201100ef9..54c9bd46803a 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -29,10 +29,4 @@ if [ -e "$sbin"/../tachyon/bin/tachyon ]; then "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 0ff521706c71..7168d5b2a8e2 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,9 +137,9 @@ - + - + diff --git a/sql/README.md b/sql/README.md index d058a6b011d3..237620e3fa80 100644 --- a/sql/README.md +++ b/sql/README.md @@ -22,7 +22,8 @@ export HADOOP_HOME="/hadoop-1.0.4" Using the console ================= -An interactive scala console can be invoked by running `build/sbt hive/console`. From here you can execute queries and inspect the various stages of query optimization. +An interactive scala console can be invoked by running `build/sbt hive/console`. +From here you can execute queries with HiveQl and manipulate DataFrame by using DSL. ```scala catalyst$ build/sbt hive/console @@ -36,45 +37,25 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.TestHive._ +import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.types._ -Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.SchemaRDD = -== Query Plan == -== Physical Plan == -HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.DataFrame = org.apache.spark.sql.DataFrame@74448eed ``` -Query results are RDDs and can be operated as such. +Query results are `DataFrames` and can be operated as such. ``` scala> query.collect() res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` -You can also build further queries on top of these RDDs using the query DSL. +You can also build further queries on top of these `DataFrames` using the query DSL. ``` -scala> query.where('key === 100).collect() -res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) -``` - -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. -```scala -scala> query.queryExecution.analyzed -res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = -Project [key#10,value#11] - Project [key#10,value#11] - MetastoreRelation default, src, None - - -scala> query.queryExecution.analyzed transform { - | case Project(projectList, child) if projectList == child.output => child - | } -res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = -Project [key#10,value#11] - MetastoreRelation default, src, None +scala> query.where(query("key") > 30).select(avg(query("key"))).collect() +res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index a1947fb022e5..3dea2ee76542 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala new file mode 100644 index 000000000000..f9992185a456 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Thrown when a query fails to analyze, usually because the query itself is invalid. + */ +@DeveloperApi +class AnalysisException protected[sql] ( + val message: String, + val line: Option[Int] = None, + val startPosition: Option[Int] = None) + extends Exception with Serializable { + + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { + val newException = new AnalysisException(message, line, startPosition) + newException.setStackTrace(getStackTrace) + newException + } + + override def getMessage: String = { + val lineAnnotation = line.map(l => s" line $l").getOrElse("") + val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") + s"$message;$lineAnnotation$positionAnnotation" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 41bb4f012f2e..4190b7ffe1c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow - +import org.apache.spark.sql.types.StructType object Row { /** @@ -122,6 +122,11 @@ trait Row extends Serializable { /** Number of elements in the Row. */ def length: Int + /** + * Schema for the row. + */ + def schema: StructType = null + /** * Returns the value at position i. If the value is null, null is returned. The following * is a mapping between Spark SQL types and return types: @@ -252,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** @@ -300,6 +306,38 @@ trait Row extends Serializable { */ def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) + + /** + * Returns the index of a given field name. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + */ + def fieldIndex(name: String): Int = { + throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.") + } + + /** + * Returns a Map(name -> value) for the requested fieldNames + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = { + fieldNames.map { name => + name -> getAs[T](name) + }.toMap + } + override def toString(): String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 366be00473d1..85ac4ac3fa39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,13 +26,13 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] object KeywordNormalizer { - def apply(str: String) = str.toLowerCase() + def apply(str: String): String = str.toLowerCase() } private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = { // Initialize the Keywords. lexical.initialize(reservedWords) phrase(start)(new lexical.Scanner(input)) match { @@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize = KeywordNormalizer(str) + def normalize: String = KeywordNormalizer(str) def parser: Parser[String] = normalize } @@ -52,6 +52,8 @@ private[sql] abstract class AbstractSparkSQLParser // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this // method during the parent class instantiation, because the sub class instance // isn't created yet. + // 貌似这里 用lazy还是有玄机的,不想在初始化时父类里面调用 + // 但是说实在的,定义一个 Keyword 类,采用反射获取 reserved 很奇怪, 何不直接定义一个Seq【String】 protected lazy val reservedWords: Seq[String] = this .getClass @@ -79,9 +81,19 @@ private[sql] abstract class AbstractSparkSQLParser } } +/** + * spark sql 提供的词法解析器, 该解析器的主要解析出 keywords, identifier, numeric, string 和 delimiters + * 为了区分identifiers 和 keywords, 词法解析器根据reserved集合来匹配keywords,只要是在reserved里面的字符串都不会被解析为identifire 而是会被解析成 keyword + * 用户可以自定义reserved + * 另外分隔符可以通过 delimiters 来设置 + * + * 通常来讲,这个词法解析器用于将字符输入切割成多个token,然后我们会将这些token 传给一个token parser去解析 + * [[scala.util.parsing.combinator.syntactical.TokenParsers]].) + */ class SqlLexical extends StdLexical { + // 定义了一个 float 的token,后面解析float字符串用 case class FloatLit(chars: String) extends Token { - override def toString = chars + override def toString: String = chars } /* This is a work around to support the lazy setting */ @@ -96,7 +108,9 @@ class SqlLexical extends StdLexical { ) protected override def processIdent(name: String) = { + // 全部转为小写 val token = KeywordNormalizer(name) + // 这个够直观了吧 if (reserved contains token) Keyword(token) else Identifier(name) } @@ -120,7 +134,7 @@ class SqlLexical extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') + override def identChar: Parser[Elem] = letter | elem('_') override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala new file mode 100644 index 000000000000..a13e2f36a1a1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.lang.{Iterable => JavaIterable} +import java.util.{Map => JavaMap} + +import scala.collection.mutable.HashMap + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Functions to convert Scala types to Catalyst types and vice versa. + */ +object CatalystTypeConverters { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => + udt.serialize(obj) + + case (o: Option[_], _) => + o.map(convertToCatalyst(_, dataType)).orNull + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToCatalyst(_, arrayType.elementType)) + + case (jit: JavaIterable[_], arrayType: ArrayType) => { + val iter = jit.iterator + var listOfItems: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + listOfItems :+= convertToCatalyst(item, arrayType.elementType) + } + listOfItems + } + + case (s: Array[_], arrayType: ArrayType) => + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + + case (jmap: JavaMap[_, _], mapType: MapType) => + val iter = jmap.entrySet.iterator + var listOfEntries: List[(Any, Any)] = List() + while (iter.hasNext) { + val entry = iter.next() + listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), + convertToCatalyst(entry.getValue, mapType.valueType)) + } + listOfEntries.toMap + + case (p: Product, structType: StructType) => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case (d: String, _) => + UTF8String(d) + + case (d: BigDecimal, _) => + Decimal(d) + + case (d: java.math.BigDecimal, _) => + Decimal(d) + + case (d: java.sql.Date, _) => + DateUtils.fromJavaDate(d) + + case (r: Row, structType: StructType) => + val converters = structType.fields.map { + f => (item: Any) => convertToCatalyst(item, f.dataType) + } + convertRowWithConverters(r, structType, converters) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def extractOption(item: Any): Any = item match { + case opt: Option[_] => opt.orNull + case other => other + } + + dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item) => extractOption(item) match { + case null => null + case other => udt.serialize(other) + } + + case arrayType: ArrayType => + val elementConverter = createToCatalystConverter(arrayType.elementType) + (item: Any) => { + extractOption(item) match { + case a: Array[_] => a.toSeq.map(elementConverter) + case s: Seq[_] => s.map(elementConverter) + case i: JavaIterable[_] => { + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter(item) + } + convertedIterable + } + case null => null + } + } + + case mapType: MapType => + val keyConverter = createToCatalystConverter(mapType.keyType) + val valueConverter = createToCatalystConverter(mapType.valueType) + (item: Any) => { + extractOption(item) match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) + } + convertedMap + + case null => null + } + } + + case structType: StructType => + val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) + (item: Any) => { + extractOption(item) match { + case r: Row => + convertRowWithConverters(r, structType, converters) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx)(iter.next()) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case null => + null + } + } + + case dateType: DateType => (item: Any) => extractOption(item) match { + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case other => other + } + + case dataType: StringType => (item: Any) => extractOption(item) match { + case s: String => UTF8String(s) + case other => other + } + + case _ => + (item: Any) => extractOption(item) match { + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case other => other + } + } + } + + /** + * Converts Scala objects to catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. + */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => + udt.deserialize(d) + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToScala(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + + case (r: Row, s: StructType) => + convertRowToScala(r, s) + + case (d: Decimal, _: DecimalType) => + d.toJavaBigDecimal + + case (i: Int, DateType) => + DateUtils.toJavaDate(i) + + case (s: UTF8String, StringType) => + s.toString() + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item: Any) => if (item == null) null else udt.deserialize(item) + + case arrayType: ArrayType => + val elementConverter = createToScalaConverter(arrayType.elementType) + (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) + + case mapType: MapType => + val keyConverter = createToScalaConverter(mapType.keyType) + val valueConverter = createToScalaConverter(mapType.valueType) + (item: Any) => if (item == null) { + null + } else { + item.asInstanceOf[Map[_, _]].map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + } + + case s: StructType => + val converters = s.fields.map(f => createToScalaConverter(f.dataType)) + (item: Any) => { + if (item == null) { + null + } else { + convertRowWithConverters(item.asInstanceOf[Row], s, converters) + } + } + + case _: DecimalType => + (item: Any) => item match { + case d: Decimal => d.toJavaBigDecimal + case other => other + } + + case DateType => + (item: Any) => item match { + case i: Int => DateUtils.toJavaDate(i) + case other => other + } + + case StringType => + (item: Any) => item match { + case s: UTF8String => s.toString() + case other => other + } + + case other => + (item: Any) => item + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + val ar = new Array[Any](r.size) + var idx = 0 + while (idx < r.size) { + ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } + + /** + * Converts a row by applying the provided set of converter functions. It is used for both + * toScala and toCatalyst conversions. + */ + private[sql] def convertRowWithConverters( + row: Row, + schema: StructType, + converters: Array[Any => Any]): Row = { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx)(row(idx)) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 191d16fb10b5..c52965507c71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.catalyst -import java.sql.{Date, Timestamp} - import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ - /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -47,60 +44,18 @@ trait ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** - * Converts Scala objects to catalyst rows / types. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) - case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } - case (p: Product, structType: StructType) => - new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (d: BigDecimal, _) => Decimal(d) - case (d: java.math.BigDecimal, _) => Decimal(d) - case (other, _) => other - } - - /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => udt.deserialize(d) - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - case (r: Row, s: StructType) => convertRowToScala(r, s) - case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal - case (other, _) => other - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - // TODO: This is very slow!!! - new GenericRow( - r.toSeq.zip(schema.fields.map(_.dataType)) - .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) - } - /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + s.toAttributes } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(typeOf[T]) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = { + def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { val className: String = tpe.erasure.typeSymbol.asClass.fullName tpe match { case t if Utils.classIsLoadable(className) && @@ -115,6 +70,21 @@ trait ScalaReflection { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -137,22 +107,9 @@ trait ScalaReflection { schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) StructField(p.name.toString, dataType, nullable) }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) @@ -170,24 +127,27 @@ trait ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case other => + throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. - case obj: BooleanType.JvmType => BooleanType - case obj: BinaryType.JvmType => BinaryType - case obj: StringType.JvmType => StringType - case obj: ByteType.JvmType => ByteType - case obj: ShortType.JvmType => ShortType - case obj: IntegerType.JvmType => IntegerType - case obj: LongType.JvmType => LongType - case obj: FloatType.JvmType => FloatType - case obj: DoubleType.JvmType => DoubleType - case obj: DateType.JvmType => DateType + case obj: Boolean => BooleanType + case obj: Array[Byte] => BinaryType + case obj: String => StringType + case obj: UTF8String => StringType + case obj: Byte => ByteType + case obj: Short => ShortType + case obj: Int => IntegerType + case obj: Long => LongType + case obj: Float => FloatType + case obj: Double => DoubleType + case obj: java.sql.Date => DateType case obj: java.math.BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited - case obj: TimestampType.JvmType => TimestampType + case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a // Catalyst data type. A user should provide his/her specific rules @@ -203,7 +163,7 @@ trait ScalaReflection { */ def asRelation: LocalRelation = { val output = attributesFor[A] - LocalRelation(output, data) + LocalRelation.fromProduct(output, data) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala old mode 100755 new mode 100644 index eaadbe9fd509..0af969cc5cc6 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -35,7 +35,17 @@ import org.apache.spark.sql.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser { +class SqlParser extends AbstractSparkSQLParser with DataTypeParser { + + def parseExpression(input: String): Expression = { + // Initialize the Keywords. + lexical.initialize(reservedWords) + phrase(projection)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + } + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ABS = Keyword("ABS") @@ -47,14 +57,12 @@ class SqlParser extends AbstractSparkSQLParser { protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") - protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") - protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") - protected val DOUBLE = Keyword("DOUBLE") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") @@ -93,18 +101,17 @@ class SqlParser extends AbstractSparkSQLParser { protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") protected val SQRT = Keyword("SQRT") - protected val STRING = Keyword("STRING") protected val SUBSTR = Keyword("SUBSTR") protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") - protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") + protected val WITH = Keyword("WITH") protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -114,13 +121,14 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val start: Parser[LogicalPlan] = - ( (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - | insert + start1 | insert | cte + + protected lazy val start1: Parser[LogicalPlan] = + (select | ("(" ~> select <~ ")")) * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) protected lazy val select: Parser[LogicalPlan] = @@ -133,7 +141,7 @@ class SqlParser extends AbstractSparkSQLParser { sortType.? ~ (LIMIT ~> expression).? ^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) + val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) @@ -146,8 +154,13 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) + INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) + } + + protected lazy val cte: Parser[LogicalPlan] = + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { + case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) } protected lazy val projection: Parser[Expression] = @@ -295,6 +308,7 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } + | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ @@ -302,18 +316,20 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { + case exp ~ t => Cast(exp, t) + } protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } - | NULL ^^^ Literal(null, NullType) + | stringLit ^^ {case s => Literal.create(s, StringType) } + | NULL ^^^ Literal.create(null, NullType) ) protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) + ( TRUE ^^^ Literal.create(true, BooleanType) + | FALSE ^^^ Literal.create(false, BooleanType) ) protected lazy val numericLiteral: Parser[Literal] = @@ -348,7 +364,7 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ Star(None) + ( "*" ^^^ UnresolvedStar(None) | primary ) @@ -360,31 +376,18 @@ class SqlParser extends AbstractSparkSQLParser { | expression ~ ("[" ~> expression <~ "]") ^^ { case base ~ ordinal => GetItem(base, ordinal) } | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => GetField(base, fieldName) } + { case base ~ fieldName => UnresolvedGetField(base, fieldName) } | cast | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ UnresolvedAttribute + | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", "")) - } - - protected lazy val dataType: Parser[DataType] = - ( STRING ^^^ StringType - | TIMESTAMP ^^^ TimestampType - | DOUBLE ^^^ DoubleType - | fixedDecimalType - | DECIMAL ^^^ DecimalType.Unlimited - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7f4cc234dc9c..5e42b409dcc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,16 +18,15 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyser needs only to resolve attribute + * when all relations are already filled in and the analyzer needs only to resolve attribute * references. */ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) @@ -37,11 +36,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and * a [[FunctionRegistry]]. */ -class Analyzer(catalog: Catalog, - registry: FunctionRegistry, - caseSensitive: Boolean, - maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { +class Analyzer( + catalog: Catalog, + registry: FunctionRegistry, + caseSensitive: Boolean, + maxIterations: Int = 100) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution @@ -50,51 +50,24 @@ class Analyzer(catalog: Catalog, /** * Override to provide additional rules for the "Resolution" batch. */ - val extendedRules: Seq[Rule[LogicalPlan]] = Nil + val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( - Batch("MultiInstanceRelations", Once, - NewRelationInstances), Batch("Resolution", fixedPoint, - ResolveReferences :: ResolveRelations :: + ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: - NewRelationInstances :: + ResolveGenerate :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: TrimGroupingAliases :: typeCoercionRules ++ - extendedRules : _*), - Batch("Check Analysis", Once, - CheckResolution, - CheckAggregation), - Batch("AnalysisOperators", fixedPoint, - EliminateAnalysisOperators) + extendedResolutionRules : _*) ) - /** - * Makes sure all attributes and logical plans have been resolved. - */ - object CheckResolution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan.transform { - case p if p.expressions.exists(!_.resolved) => - throw new TreeNodeException(p, - s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}") - case p if !p.resolved && p.childrenResolved => - throw new TreeNodeException(p, "Unresolved plan found") - } match { - // As a backstop, use the root node to check that the entire plan tree is resolved. - case p if !p.resolved => - throw new TreeNodeException(p, "Unresolved plan in tree") - case p => p - } - } - } - /** * Removes no-op Alias expressions from the plan. */ @@ -167,10 +140,10 @@ class Analyzer(catalog: Catalog, case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal(null, expr.dataType) + Literal.create(null, expr.dataType) case x if x == g.gid => // replace the groupingId with concrete value (the bit mask) - Literal(bitmask, IntegerType) + Literal.create(bitmask, IntegerType) }) result += GroupExpression(substitution) @@ -193,46 +166,39 @@ class Analyzer(catalog: Catalog, } /** - * Checks for non-aggregated attributes with aggregation + * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ - object CheckAggregation extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan.transform { - case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => - def isValidAggregateExpression(expr: Expression): Boolean = expr match { - case _: AggregateExpression => true - case e: Attribute => groupingExprs.contains(e) - case e if groupingExprs.contains(e) => true - case e if e.references.isEmpty => true - case e => e.children.forall(isValidAggregateExpression) - } - - aggregateExprs.find { e => - !isValidAggregateExpression(e.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g: GetField, _) => g - }) - }.foreach { e => - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } - - aggregatePlan + object ResolveRelations extends Rule[LogicalPlan] { + def getTable(u: UnresolvedRelation, cteRelations: Map[String, LogicalPlan]): LogicalPlan = { + try { + // In hive, if there is same table name in database and CTE definition, + // hive will use the table in database, not the CTE one. + // Taking into account the reasonableness and the implementation complexity, + // here use the CTE definition first, check table name only and ignore database name + cteRelations.get(u.tableIdentifier.last) + .map(relation => u.alias.map(Subquery(_, relation)).getOrElse(relation)) + .getOrElse(catalog.lookupRelation(u.tableIdentifier, u.alias)) + } catch { + case _: NoSuchTableException => + u.failAnalysis(s"no such table ${u.tableName}") } } - } - /** - * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. - */ - object ResolveRelations extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) => - i.copy( - table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias))) - case UnresolvedRelation(tableIdentifier, alias) => - catalog.lookupRelation(tableIdentifier, alias) + def apply(plan: LogicalPlan): LogicalPlan = { + val (realPlan, cteRelations) = plan match { + // TODO allow subquery to define CTE + // Add cte table to a temp relation map,drop `with` plan and keep its child + case With(child, relations) => (child, relations) + case other => (other, Map.empty[String, LogicalPlan]) + } + + realPlan transform { + case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + i.copy( + table = EliminateSubQueries(getTable(u, cteRelations))) + case u: UnresolvedRelation => + getTable(u, cteRelations) + } } } @@ -250,6 +216,24 @@ class Analyzer(catalog: Catalog, Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) + case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(child = f.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateArray(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil + case Alias(c @ CreateStruct(args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(c.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -270,40 +254,85 @@ class Analyzer(catalog: Catalog, } ) + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") + + val (oldRelation, newRelation) = right.collect { + // Handle base relations that might appear more than once. + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. + sys.error( + s""" + |Failure when resolving conflicting references in Join: + |$plan + | + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + """.stripMargin) + } + + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + j.copy(right = newRight) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressions { - case u @ UnresolvedAttribute(name) - if resolver(name, VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics + q transformExpressionsUp { + case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && + resolver(nameParts(0), VirtualColumn.groupingIdName) && + q.isInstanceOf[GroupingAnalytics] => + // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid - case u @ UnresolvedAttribute(name) => + case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolveChildren(name, resolver).getOrElse(u) + val result = + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result - - // Resolve field names using the resolver. - case f @ GetField(child, fieldName) if !f.resolved && child.resolved => - child.dataType match { - case StructType(fields) => - val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName)) - resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f) - case _ => f - } + case UnresolvedGetField(child, fieldName) if child.resolved => + GetField(child, fieldName, resolver) } } + def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { + expressions.map { + case a: Alias => Alias(a.child, a.name)() + case other => other + } + } + + def findAliases(projectList: Seq[NamedExpression]): AttributeSet = { + AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) + } + /** * Returns true if `exprs` contains a [[Star]]. */ protected def containsStar(exprs: Seq[Expression]): Boolean = - exprs.collect { case _: Star => true}.nonEmpty + exprs.exists(_.collect { case _: Star => true }.nonEmpty) } /** - * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. @@ -312,18 +341,16 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolve(_, resolver)) - val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) - val missingInProject = requiredAttributes -- p.output - if (missingInProject.nonEmpty) { + // If this rule was not a no-op, return the transformed plan, otherwise return the original. + if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. - Project(projectList.map(_.toAttribute), - Sort(ordering, global, - Project(projectList ++ missingInProject, child))) + Project(p.output, + Sort(resolvedOrdering, global, + Project(projectList ++ missing, child))) } else { - logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") + logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) @@ -335,18 +362,54 @@ class Analyzer(catalog: Catalog, grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) - val missingInAggs = resolved.filterNot(a.outputSet.contains) - logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") - if (missingInAggs.nonEmpty) { + val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + + if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(resolvedOrdering, global, + Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. } } + + /** + * Given a child and a grandchild that are present beneath a sort operator, returns + * a resolved sort ordering and a list of attributes that are missing from the child + * but are present in the grandchild. + */ + def resolveAndFindMissing( + ordering: Seq[SortOrder], + child: LogicalPlan, + grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { + // Find any attributes that remain unresolved in the sort. + val unresolved: Seq[Seq[String]] = + ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) + + // Create a map from name, to resolved attributes, when the desired name can be found + // prior to the projection. + val resolved: Map[Seq[String], NamedExpression] = + unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap + + // Construct a set that contains all of the attributes that we need to evaluate the + // ordering. + val requiredAttributes = AttributeSet(resolved.values) + + // Figure out which ones are missing from the projection, so that we can add them and + // remove them after the sort. + val missingInProject = requiredAttributes -- child.output + + // Now that we have all the attributes we need, reconstruct a resolved ordering. + // It is important to do it here, instead of waiting for the standard resolved as adding + // attributes to the project below can actually introduce ambiquity that was not present + // before. + val resolvedOrdering = ordering.map(_ transform { + case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) + }).asInstanceOf[Seq[SortOrder]] + + (resolvedOrdering, missingInProject.toSeq) + } } /** @@ -411,8 +474,59 @@ class Analyzer(catalog: Catalog, */ object ImplicitGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, _)), child) => - Generate(g, join = false, outer = false, None, child) + case Project(Seq(Alias(g: Generator, name)), child) => + Generate(g, join = false, outer = false, + qualifier = None, UnresolvedAttribute(name) :: Nil, child) + case Project(Seq(MultiAlias(g: Generator, names)), child) => + Generate(g, join = false, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), child) + } + } + + /** + * Resolve the Generate, if the output names specified, we will take them, otherwise + * we will try to provide the default names, which follow the same rule with Hive. + */ + object ResolveGenerate extends Rule[LogicalPlan] { + // Construct the output attributes for the generator, + // The output attribute names can be either specified or + // auto generated. + private def makeGeneratorOutput( + generator: Generator, + generatorOutput: Seq[Attribute]): Seq[Attribute] = { + val elementTypes = generator.elementTypes + + if (generatorOutput.length == elementTypes.length) { + generatorOutput.zip(elementTypes).map { + case (a, (t, nullable)) if !a.resolved => + AttributeReference(a.name, t, nullable)() + case (a, _) => a + } + } else if (generatorOutput.length == 0) { + elementTypes.zipWithIndex.map { + // keep the default column names as Hive does _c0, _c1, _cN + case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + } + } else { + throw new AnalysisException( + s""" + |The number of aliases supplied in the AS clause does not match + |the number of columns output by the UDTF expected + |${elementTypes.size} aliases but got ${generatorOutput.size} + """.stripMargin) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case p: Generate if p.resolved == false => + // if the generator output names are not specified, we will use the default ones. + Generate( + p.generator, + join = p.join, + outer = p.outer, + p.qualifier, + makeGeneratorOutput(p.generator, p.generatorOutput), p.child) } } } @@ -422,7 +536,7 @@ class Analyzer(catalog: Catalog, * only required to provide scoping information for attributes and can be removed once analysis is * complete. */ -object EliminateAnalysisOperators extends Rule[LogicalPlan] { +object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Subquery(_, child) => child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index df8d03b86c53..b2f8157a1a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -21,6 +21,12 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} +/** + * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception + * as an AnalysisException with the correct position information. + */ +class NoSuchTableException extends Exception + /** * An interface for looking up relations by name. Used by an [[Analyzer]]. */ @@ -34,6 +40,14 @@ trait Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan + /** + * Returns tuples of (tableName, isTemporary) for all tables in the given database. + * isTemporary is a Boolean value indicates if a table is a temporary or not. + */ + def getTables(databaseName: Option[String]): Seq[(String, Boolean)] + + def refreshTable(databaseName: String, tableName: String): Unit + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit def unregisterTable(tableIdentifier: Seq[String]): Unit @@ -72,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable(tableIdentifier: Seq[String]) = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) tables -= getDbTableName(tableIdent) } - override def unregisterAllTables() = { + override def unregisterAllTables(): Unit = { tables.clear() } @@ -101,6 +115,16 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { // properly qualified with this alias. alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + tables.map { + case (name, _) => (name, true) + }.toSeq + } + + override def refreshTable(databaseName: String, tableName: String): Unit = { + throw new UnsupportedOperationException + } } /** @@ -123,8 +147,8 @@ trait OverrideCatalog extends Catalog { } abstract override def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan = { + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val overriddenTable = overrides.get(getDBTable(tableIdent)) val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) @@ -137,6 +161,27 @@ trait OverrideCatalog extends Catalog { withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) } + abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + val dbName = if (!caseSensitive) { + if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None + } else { + databaseName + } + + val temporaryTables = overrides.filter { + // If a temporary table does not have an associated database, we should return its name. + case ((None, _), _) => true + // If a temporary table does have an associated database, we should return it if the database + // matches the given database name. + case ((db: Some[String], _), _) if db == dbName => true + case _ => false + }.map { + case ((_, tableName), _) => (tableName, true) + }.toSeq + + temporaryTables ++ super.getTables(databaseName) + } + override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { @@ -156,29 +201,37 @@ trait OverrideCatalog extends Catalog { /** * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyser needs only to resolve attribute references. + * relations are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyCatalog extends Catalog { - val caseSensitive: Boolean = true + override val caseSensitive: Boolean = true - def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None) = { + override def lookupRelation( + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { + throw new UnsupportedOperationException + } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { throw new UnsupportedOperationException } - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } override def unregisterAllTables(): Unit = {} + + override def refreshTable(databaseName: String, tableName: String): Unit = { + throw new UnsupportedOperationException + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala new file mode 100644 index 000000000000..2381689e1752 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +/** + * Throws user facing errors when passed invalid queries that fail to analyze. + */ +trait CheckAnalysis { + self: Analyzer => + + /** + * Override to provide additional checks for correct analysis. + * These rules will be evaluated after our built-in check rules. + */ + val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + + protected def failAnalysis(msg: String): Nothing = { + throw new AnalysisException(msg) + } + + def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + exprs.flatMap(_.collect { + case e: Generator => true + }).length >= 1 + } + + def checkAnalysis(plan: LogicalPlan): Unit = { + // We transform up and order the rules so as to catch the first possible failure instead + // of the result of cascading resolution failures. + plan.foreachUp { + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + if (operator.childrenResolved) { + a match { + case UnresolvedAttribute(nameParts) => + // Throw errors for specific problems with get field. + operator.resolveChildren(nameParts, resolver, throwErrors = true) + } + } + + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case b: BinaryExpression if !b.resolved => + failAnalysis( + s"invalid expression ${b.prettyString} " + + s"between ${b.left.simpleString} and ${b.right.simpleString}") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.contains(e) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.contains(e) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + val cleaned = aggregateExprs.map(_.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g, _) => g + }) + + cleaned.foreach(checkValidAggregateExpression) + + case _ => // Fallbacks to the following checks + } + + operator match { + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") + + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + + case _ => // Analysis successful! + } + } + extendedCheckRules.foreach(_(plan)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 760c49fbca4a..16ca5bcd57a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -27,25 +27,27 @@ trait FunctionRegistry { def registerFunction(name: String, builder: FunctionBuilder): Unit def lookupFunction(name: String, children: Seq[Expression]): Expression + + def caseSensitive: Boolean } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) } } -class SimpleFunctionRegistry extends FunctionRegistry { - val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() +class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { + val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -55,13 +57,41 @@ class SimpleFunctionRegistry extends FunctionRegistry { } /** - * A trivial catalog that returns an error when a function is requested. Used for testing when all - * functions are already filled in and the analyser needs only to resolve attribute references. + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder) = ??? + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } + + override def caseSensitive: Boolean = throw new UnsupportedOperationException } + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + * TODO move this into util folder? + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + def remove(key: String): Option[T] = base.remove(normalizer(key)) + def iterator: Iterator[(String, T)] = base.toIterator +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ef8577fd04d..35c7f00d4e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -78,6 +78,7 @@ trait HiveTypeCoercion { FunctionArgumentConversion :: CaseWhenCoercion :: Division :: + PropagateTypes :: Nil /** @@ -114,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -284,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -295,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -310,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -327,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip + + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** @@ -503,6 +562,26 @@ trait HiveTypeCoercion { // Hive lets you do aggregation of timestamps... for some reason case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + + // Coalesce should return the first non-null value, which could be any column + // from the list. So we need to make sure the return type is deterministic and + // compatible with every child column. + case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + val dt: Option[DataType] = Some(NullType) + val types = es.map(_.dataType) + val rt = types.foldLeft(dt)((r, c) => r match { + case None => None + case Some(d) => findTightestCommonType(d, c) + }) + rt match { + case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + case None => + sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 22941edef2d4..35b74024a4ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * produced by distinct operators in a query tree as this breaks the guarantee that expression * ids, which are used to differentiate attributes, are unique. * - * Before analysis, all operators that include this trait will be asked to produce a new version + * During analysis, operators that include this trait may be asked to produce a new version * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance(): this.type -} - -/** - * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so - * that each instance has unique expression ids for the attributes produced. - */ -object NewRelationInstances extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val localRelations = plan collect { case l: MultiInstanceRelation => l} - val multiAppearance = localRelations - .groupBy(identity[MultiInstanceRelation]) - .filter { case (_, ls) => ls.size > 1 } - .map(_._1) - .toSet - - plan transform { - case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance - } - } + def newInstance(): LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 3f672a3e0fd9..7731336d247d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.trees.TreeNode + /** * Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis. * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s @@ -25,11 +28,26 @@ package org.apache.spark.sql.catalyst package object analysis { /** - * Responsible for resolving which identifiers refer to the same entity. For example, by using - * case insensitive equality. + * Resolver should return true if the first string refers to the same entity as the second string. + * For example, by using case insensitive equality. */ type Resolver = (String, String) => Boolean val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b) val caseSensitiveResolution = (a: String, b: String) => a == b + + implicit class AnalysisErrorAt(t: TreeNode[_]) { + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) + } + } + + /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ + def withPosition[A](t: TreeNode[_])(f: => A): A = { + try f catch { + case a: AnalysisException => + throw a.withPosition(t.origin.line, t.origin.startPosition) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71a738a0b2ca..3f567e3e8b2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.DataType /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -36,24 +37,34 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( tableIdentifier: Seq[String], alias: Option[String] = None) extends LeafNode { - override def output = Nil + + /** Returns a `.` separated name for this relation. */ + def tableName: String = tableIdentifier.mkString(".") + + override def output: Seq[Attribute] = Nil + override lazy val resolved = false } /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") +case class UnresolvedAttribute(nameParts: Seq[String]) + extends Attribute with trees.LeafNode[Expression] { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = UnresolvedAttribute(name) + override def newInstance(): UnresolvedAttribute = this + override def withNullability(newNullability: Boolean): UnresolvedAttribute = this + override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -62,19 +73,51 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } +object UnresolvedAttribute { + def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) +} + case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"'$name(${children.mkString(",")})" + override def toString: String = s"'$name(${children.mkString(",")})" +} + +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. + */ +trait Star extends Attribute with trees.LeafNode[Expression] { + self: Product => + + override def name: String = throw new UnresolvedException(this, "name") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override lazy val resolved = false + + override def newInstance(): Star = this + override def withNullability(newNullability: Boolean): Star = this + override def withQualifiers(newQualifiers: Seq[String]): Star = this + override def withName(newName: String): Star = this + + // Star gets expanded at runtime so we never evaluate a Star. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] } + /** * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". @@ -82,41 +125,83 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E * @param table an optional table that should be the target of the expansion. If omitted all * tables' columns are produced. */ -case class Star( - table: Option[String], - mapFunction: Attribute => Expression = identity[Attribute]) - extends Attribute with trees.LeafNode[Expression] { - - override def name = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") - override lazy val resolved = false - - override def newInstance = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = this +case class UnresolvedStar(table: Option[String]) extends Star { - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { // If there is no table specified, use all input attributes. case None => input // If there is a table, pick out attributes that are part of this table. case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) } - val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { + expandedAttributes.zip(input).map { case (n: NamedExpression, _) => n case (e, originalAttribute) => Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) } - mappedAttributes } - // Star gets expanded at runtime so we never evaluate a Star. + override def toString: String = table.map(_ + ".").getOrElse("") + "*" +} + +/** + * Used to assign new names to Generator's output, such as hive udtf. + * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented + * as follows: + * MultiAlias(stack_function, Seq(a, b)) + + * @param child the computation being performed + * @param names the names to be associated with each output of computing [[child]]. + */ +case class MultiAlias(child: Expression, names: Seq[String]) + extends Attribute with trees.UnaryNode[Expression] { + + override def name: String = throw new UnresolvedException(this, "name") + + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + + override lazy val resolved = false + + override def newInstance(): MultiAlias = this + + override def withNullability(newNullability: Boolean): MultiAlias = this + + override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this + + override def withName(newName: String): MultiAlias = this + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + override def toString: String = s"$child AS $names" + +} + +/** + * Represents all the resolved input attributes to a given relational operator. This is used + * in the data frame DSL. + * + * @param expressions Expressions to expand. + */ +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") +} + +case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = s"$child.$fieldName" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala old mode 100755 new mode 100644 index 417659eed595..5d5aba9644ff --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -61,60 +61,60 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- = UnaryMinus(expr) - def unary_! = Not(expr) - def unary_~ = BitwiseNot(expr) - - def + (other: Expression) = Add(expr, other) - def - (other: Expression) = Subtract(expr, other) - def * (other: Expression) = Multiply(expr, other) - def / (other: Expression) = Divide(expr, other) - def % (other: Expression) = Remainder(expr, other) - def & (other: Expression) = BitwiseAnd(expr, other) - def | (other: Expression) = BitwiseOr(expr, other) - def ^ (other: Expression) = BitwiseXor(expr, other) - - def && (other: Expression) = And(expr, other) - def || (other: Expression) = Or(expr, other) - - def < (other: Expression) = LessThan(expr, other) - def <= (other: Expression) = LessThanOrEqual(expr, other) - def > (other: Expression) = GreaterThan(expr, other) - def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = EqualTo(expr, other) - def <=> (other: Expression) = EqualNullSafe(expr, other) - def !== (other: Expression) = Not(EqualTo(expr, other)) - - def in(list: Expression*) = In(expr, list) - - def like(other: Expression) = Like(expr, other) - def rlike(other: Expression) = RLike(expr, other) - def contains(other: Expression) = Contains(expr, other) - def startsWith(other: Expression) = StartsWith(expr, other) - def endsWith(other: Expression) = EndsWith(expr, other) - def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def unary_- : Expression= UnaryMinus(expr) + def unary_! : Predicate = Not(expr) + def unary_~ : Expression = BitwiseNot(expr) + + def + (other: Expression): Expression = Add(expr, other) + def - (other: Expression): Expression = Subtract(expr, other) + def * (other: Expression): Expression = Multiply(expr, other) + def / (other: Expression): Expression = Divide(expr, other) + def % (other: Expression): Expression = Remainder(expr, other) + def & (other: Expression): Expression = BitwiseAnd(expr, other) + def | (other: Expression): Expression = BitwiseOr(expr, other) + def ^ (other: Expression): Expression = BitwiseXor(expr, other) + + def && (other: Expression): Predicate = And(expr, other) + def || (other: Expression): Predicate = Or(expr, other) + + def < (other: Expression): Predicate = LessThan(expr, other) + def <= (other: Expression): Predicate = LessThanOrEqual(expr, other) + def > (other: Expression): Predicate = GreaterThan(expr, other) + def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) + def === (other: Expression): Predicate = EqualTo(expr, other) + def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) + def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + + def in(list: Expression*): Expression = In(expr, list) + + def like(other: Expression): Expression = Like(expr, other) + def rlike(other: Expression): Expression = RLike(expr, other) + def contains(other: Expression): Expression = Contains(expr, other) + def startsWith(other: Expression): Expression = StartsWith(expr, other) + def endsWith(other: Expression): Expression = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def isNull = IsNull(expr) - def isNotNull = IsNotNull(expr) + def isNull: Predicate = IsNull(expr) + def isNotNull: Predicate = IsNotNull(expr) - def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = GetField(expr, fieldName) + def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal) + def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName) - def cast(to: DataType) = Cast(expr, to) + def cast(to: DataType): Expression = Cast(expr, to) - def asc = SortOrder(expr, Ascending) - def desc = SortOrder(expr, Descending) + def asc: SortOrder = SortOrder(expr, Ascending) + def desc: SortOrder = SortOrder(expr, Descending) - def as(alias: String) = Alias(expr, alias)() - def as(alias: Symbol) = Alias(expr, alias.name)() + def as(alias: String): NamedExpression = Alias(expr, alias)() + def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { implicit class DslExpression(e: Expression) extends ImplicitOperators { - def expr = e + def expr: Expression = e } implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) @@ -144,94 +144,100 @@ package object dsl { } } - def sum(e: Expression) = Sum(e) - def sumDistinct(e: Expression) = SumDistinct(e) - def count(e: Expression) = Count(e) - def countDistinct(e: Expression*) = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) - def avg(e: Expression) = Average(e) - def first(e: Expression) = First(e) - def last(e: Expression) = Last(e) - def min(e: Expression) = Min(e) - def max(e: Expression) = Max(e) - def upper(e: Expression) = Upper(e) - def lower(e: Expression) = Lower(e) - def sqrt(e: Expression) = Sqrt(e) - def abs(e: Expression) = Abs(e) - - implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + def sum(e: Expression): Expression = Sum(e) + def sumDistinct(e: Expression): Expression = SumDistinct(e) + def count(e: Expression): Expression = Count(e) + def countDistinct(e: Expression*): Expression = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = + ApproxCountDistinct(e, rsd) + def avg(e: Expression): Expression = Average(e) + def first(e: Expression): Expression = First(e) + def last(e: Expression): Expression = Last(e) + def min(e: Expression): Expression = Min(e) + def max(e: Expression): Expression = Max(e) + def upper(e: Expression): Expression = Upper(e) + def lower(e: Expression): Expression = Lower(e) + def sqrt(e: Expression): Expression = Sqrt(e) + def abs(e: Expression): Expression = Abs(e) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) - def attr = analysis.UnresolvedAttribute(s) + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } abstract class ImplicitAttribute extends ImplicitOperators { def s: String - def expr = attr - def attr = analysis.UnresolvedAttribute(s) + def expr: UnresolvedAttribute = attr + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) /** Creates a new AttributeReference of type boolean */ - def boolean = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() /** Creates a new AttributeReference of type byte */ - def byte = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() /** Creates a new AttributeReference of type short */ - def short = AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)() /** Creates a new AttributeReference of type int */ - def int = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() /** Creates a new AttributeReference of type long */ - def long = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() /** Creates a new AttributeReference of type float */ - def float = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() /** Creates a new AttributeReference of type double */ - def double = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() /** Creates a new AttributeReference of type string */ - def string = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + def decimal: AttributeReference = + AttributeReference(s, DecimalType.Unlimited, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal(precision: Int, scale: Int) = + def decimal(precision: Int, scale: Int): AttributeReference = AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ - def timestamp = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() /** Creates a new AttributeReference of type binary */ - def binary = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() /** Creates a new AttributeReference of type array */ - def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = + AttributeReference(s, ArrayType(dataType), nullable = true)() /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + def map(mapType: MapType): AttributeReference = + AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) - def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = + AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { - def notNull = a.withNullability(false) - def nullable = a.withNullability(true) + def notNull: AttributeReference = a.withNullability(false) + def nullable: AttributeReference = a.withNullability(true) // Protobuf terminology - def required = a.withNullability(false) + def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) + def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -241,23 +247,23 @@ package object dsl { abstract class LogicalPlanFunctions { def logicalPlan: LogicalPlan - def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression) = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, - condition: Option[Expression] = None) = + condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => Alias(e, e.toString)() @@ -265,41 +271,44 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) def sample( fraction: Double, withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt) = + seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) + // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None) = - Generate(generator, join, outer, None, logicalPlan) + alias: Option[String] = None): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false) = + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze = analysis.SimpleAnalyzer(logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String) = WriteToFile(path, logicalPlan) + def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) } } case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + def call(args: Expression*): ScalaUdf = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } } // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 82e760b6c691..96a11e352ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). */ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { + new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + } } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 171845ad14e3..5345696570b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,26 +17,29 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.Star protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a.exprId.hashCode() - override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + override def hashCode(): Int = a match { + case ar: AttributeReference => ar.exprId.hashCode() + case a => a.hashCode() + } + + override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 } } object AttributeSet { - def apply(a: Attribute) = - new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]) = + def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) .map(new AttributeEquals(_)).toSet) + } } /** @@ -54,8 +57,9 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ - override def equals(other: Any) = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + override def equals(other: Any): Boolean = other match { + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } @@ -78,32 +82,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in * `other`. */ - def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet) /** * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]) = + def --(other: Traversable[NamedExpression]): AttributeSet = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found * in `other`. */ - def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet) /** * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to * true. */ - override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + override def filter(f: Attribute => Boolean): AttributeSet = + new AttributeSet(baseSet.filter(ae => f(ae.a))) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in * `this` and `other`. */ - def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + def intersect(other: AttributeSet): AttributeSet = + new AttributeSet(baseSet.intersect(other.baseSet)) override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) @@ -111,7 +117,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq - override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" + override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}" override def isEmpty: Boolean = baseSet.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 76a9f08dea85..976786c29dbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def toString = s"input[$ordinal]" + override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } @@ -42,7 +42,7 @@ object BindReferences extends Logging { def bindReference[A <: Expression]( expression: A, input: Seq[Attribute], - allowFailures: Boolean = false): A = { + allowFailures: Boolean = false): A = { // allowFailures 参数用于控制是否 允许在input找不到 该expression expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ece5ee73618c..adf941ab2a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. */ @@ -29,9 +28,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - override def foldable = child.foldable + override def foldable: Boolean = child.foldable - override def nullable = forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true @@ -103,7 +102,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def toString = s"CAST($child, $dataType)" + override def toString: String = s"CAST($child, $dataType)" type EvaluatedType = Any @@ -112,26 +111,26 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Date](_, dateToString) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => // Hive would return null when cast from date to boolean - buildCast[Date](_, d => null) + buildCast[Int](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -171,7 +171,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => new Timestamp(b)) case DateType => - buildCast[Date](_, d => new Timestamp(d.getTime)) + buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -224,47 +224,34 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToDateString(ts: Timestamp): String = { - Cast.threadLocalDateFormat.get.format(ts) - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) + catch { case _: java.lang.IllegalArgumentException => null } + ) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) // Hive throws this exception as a Semantic Exception - // It is never possible to compare result when hive return with exception, so we can return null + // It is never possible to compare result when hive return with exception, + // so we can return null // NULL is more reasonable here, since the query itself obeys the grammar. case _ => _ => null } - // Date cannot be cast to long, according to hive - private[this] def dateToLong(d: Date) = null - - // Date cannot be cast to double, according to hive - private[this] def dateToDouble(d: Date) = null - - // Converts Date to string according to Hive DateWritable convention - private[this] def dateToString(d: Date): String = { - Cast.threadLocalDateFormat.get.format(d) - } - // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case x: NumericType => @@ -274,13 +261,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case x: NumericType => @@ -290,13 +277,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case x: NumericType => @@ -306,13 +293,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { + buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => - buildCast[Date](_, d => dateToLong(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case x: NumericType => @@ -336,13 +323,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => null) // date can't cast to decimal in Hive + buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) @@ -361,13 +350,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case x: NumericType => @@ -377,13 +366,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Int](_, d => null) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case x: NumericType => @@ -407,10 +396,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - // TODO: This is very slow! - buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { - case (v, cast) => if (v == null) null else cast(v) - }: _*)) + // TODO: Could be faster? + val newRow = new GenericMutableRow(from.fields.size) + buildCast[Row](_, row => { + var i = 0 + while (i < row.length) { + val v = row(i) + newRow.update(i, if (v == null) null else casts(i)(v)) + i += 1 + } + newRow.copy() + }) } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { @@ -442,16 +438,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd") + private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cf14992ef835..4e3bbc06a5b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode @@ -64,206 +65,17 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * A set of helper functions that return the correct descendant of `scala.math.Numeric[T]` type - * and do any casting necessary of child evaluation. + * Returns a string representation of this expression that does not have developer centric + * debugging information like the expression id. */ - @inline - def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = { - val evalE = e.eval(i) - if (evalE == null) { - null - } else { - e.dataType match { - case n: NumericType => - val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType] - castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - - /** - * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed - * to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def n2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Numeric[Any], Any, Any) => Any)): Any = { - - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case n: NumericType => - f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => n.JvmType]( - n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - } - - /** - * Evaluation helper function for 2 Fractional children expressions. Those expressions are - * supposed to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def f2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Fractional[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i: Row) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i: Row) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case ft: FractionalType => - f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType]( - ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType]) - case other => sys.error(s"Type $other does not support fractional operations") - } - } - } - } - - /** - * Evaluation helper function for 1 Fractional children expression. - * if the expression result is null, the evaluation result should be null. - */ - @inline - protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = { - val evalE1 = e1.eval(i: Row) - if(evalE1 == null) { - null - } else { - e1.dataType match { - case ft: FractionalType => - f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType]( - ft.fractional, evalE1.asInstanceOf[ft.JvmType]) - case other => sys.error(s"Type $other does not support fractional operations") - } - } - } - - /** - * Evaluation helper function for 2 Integral children expressions. Those expressions are - * supposed to be in the same data type, and also the return type. - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def i2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Integral[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case i: IntegralType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( - i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case i: FractionalType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( - i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - } - - /** - * Evaluation helper function for 1 Integral children expression. - * if the expression result is null, the evaluation result should be null. - */ - @inline - protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = { - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - e1.dataType match { - case i: IntegralType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType]( - i.integral, evalE1.asInstanceOf[i.JvmType]) - case i: FractionalType => - f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType]( - i.asIntegral, evalE1.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support numeric operations") - } - } - } - - /** - * Evaluation helper function for 2 Comparable children expressions. Those expressions are - * supposed to be in the same data type, and the return type should be Integer: - * Negative value: 1st argument less than 2nd argument - * Zero: 1st argument equals 2nd argument - * Positive value: 1st argument greater than 2nd argument - * - * Either one of the expressions result is null, the evaluation result should be null. - */ - @inline - protected final def c2( - i: Row, - e1: Expression, - e2: Expression, - f: ((Ordering[Any], Any, Any) => Any)): Any = { - if (e1.dataType != e2.dataType) { - throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") - } - - val evalE1 = e1.eval(i) - if(evalE1 == null) { - null - } else { - val evalE2 = e2.eval(i) - if (evalE2 == null) { - null - } else { - e1.dataType match { - case i: NativeType => - f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean]( - i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) - case other => sys.error(s"Type $other does not support ordered operations") - } - } - } + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } } @@ -272,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express def symbol: String - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def toString = s"($left $symbol $right)" + override def toString: String = s"($left $symbol $right)" } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -292,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = ??? - override def nullable = false - override def foldable = false - override def dataType = ??? + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = false + override def foldable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index db5d897ee569..c2866cd95540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericRow(outputArray) } - override def toString = s"Row => [${exprArray.mkString(",")}]" + override def toString: String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -107,12 +107,12 @@ class JoinedRow extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -142,7 +142,7 @@ class JoinedRow extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -153,7 +153,7 @@ class JoinedRow extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -207,12 +207,12 @@ class JoinedRow2 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -242,7 +242,7 @@ class JoinedRow2 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -253,7 +253,7 @@ class JoinedRow2 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -301,12 +301,12 @@ class JoinedRow3 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -336,7 +336,7 @@ class JoinedRow3 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -347,7 +347,7 @@ class JoinedRow3 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -395,12 +395,12 @@ class JoinedRow4 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -430,7 +430,7 @@ class JoinedRow4 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -441,7 +441,7 @@ class JoinedRow4 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -489,12 +489,12 @@ class JoinedRow5 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -524,7 +524,7 @@ class JoinedRow5 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -535,7 +535,7 @@ class JoinedRow5 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index b2c6d3029031..f5fea3f015dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.types.DoubleType + +import org.apache.spark.sql.types.{DataType, DoubleType} case object Rand extends LeafExpression { - override def dataType = DoubleType - override def nullable = false + override def dataType: DataType = DoubleType + override def nullable: Boolean = false private[this] lazy val rand = new Random - override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def eval(input: Row = null): EvaluatedType = { + rand.nextDouble().asInstanceOf[EvaluatedType] + } - override def toString = "RAND()" + override def toString: String = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790..9a77ca624ebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType /** @@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true - override def toString = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off @@ -39,363 +39,924 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) - - s""" - case $x => - function.asInstanceOf[($anys) => Any]( - $evals) - """ + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) + + s"""case $x => + val func = function.asInstanceOf[($anys) => Any] + $childs + $converters + (input: Row) => { + func( + $evals) + } + """ }.foreach(println) */ - - override def eval(input: Row): Any = { - val result = children.size match { - case 0 => function.asInstanceOf[() => Any]() - case 1 => - function.asInstanceOf[(Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) - - - case 2 => - function.asInstanceOf[(Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) - - - case 3 => - function.asInstanceOf[(Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) - - - case 4 => - function.asInstanceOf[(Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) - - - case 5 => - function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) - - - case 6 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) - - - case 7 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) - - - case 8 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) - - - case 9 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) - - - case 10 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) - - - case 11 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) - - - case 12 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) - - - case 13 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) - - - case 14 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) - - - case 15 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) - - - case 16 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) - - - case 17 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) - - - case 18 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) - - - case 19 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) - - - case 20 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) - - - case 21 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) - - - case 22 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), - ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) - - } - // scalastyle:on - - ScalaReflection.convertToCatalyst(result, dataType) + + val f = children.size match { + case 0 => + val func = function.asInstanceOf[() => Any] + (input: Row) => { + func() + } + + case 1 => + val func = function.asInstanceOf[(Any) => Any] + val child0 = children(0) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + (input: Row) => { + func( + converter0(child0.eval(input))) + } + + case 2 => + val func = function.asInstanceOf[(Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input))) + } + + case 3 => + val func = function.asInstanceOf[(Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input))) + } + + case 4 => + val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input))) + } + + case 5 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input))) + } + + case 6 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input))) + } + + case 7 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input))) + } + + case 8 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input))) + } + + case 9 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input))) + } + + case 10 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input))) + } + + case 11 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input))) + } + + case 12 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input))) + } + + case 13 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input))) + } + + case 14 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input))) + } + + case 15 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input))) + } + + case 16 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input))) + } + + case 17 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input))) + } + + case 18 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input))) + } + + case 19 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input))) + } + + case 20 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input))) + } + + case 21 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input))) + } + + case 22 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + val child21 = children(21) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) + (input: Row) => { + func( + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input)), + converter21(child21.eval(input))) + } } + + // scalastyle:on + + override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d00b2ac09745..83074eb1e631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.DataType abstract sealed class SortDirection case object Ascending extends SortDirection @@ -31,12 +32,12 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 7434165f654f..3475ed05f445 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) - def copy(): this.type + def copy(): MutableValue } final class MutableInt extends MutableValue { var value: Int = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Int] + value = v.asInstanceOf[Int] } - def copy() = { + override def copy(): MutableInt = { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableInt] } } final class MutableFloat extends MutableValue { var value: Float = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Float] + value = v.asInstanceOf[Float] } - def copy() = { + override def copy(): MutableFloat = { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableFloat] } } final class MutableBoolean extends MutableValue { var value: Boolean = false - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Boolean] + value = v.asInstanceOf[Boolean] } - def copy() = { + override def copy(): MutableBoolean = { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableBoolean] } } final class MutableDouble extends MutableValue { var value: Double = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Double] + value = v.asInstanceOf[Double] } - def copy() = { + override def copy(): MutableDouble = { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableDouble] } } final class MutableShort extends MutableValue { var value: Short = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Short] } - def copy() = { + override def copy(): MutableShort = { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableShort] } } final class MutableLong extends MutableValue { var value: Long = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Long] } - def copy() = { + override def copy(): MutableLong = { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableLong] } } final class MutableByte extends MutableValue { var value: Byte = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Byte] } - def copy() = { + override def copy(): MutableByte = { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableByte] } } final class MutableAny extends MutableValue { var value: Any = _ - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Any] + value = v.asInstanceOf[Any] } - def copy() = { + override def copy(): MutableAny = { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableAny] } } @@ -220,22 +220,27 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def isNullAt(i: Int): Boolean = values(i).isNull override def copy(): Row = { - val newValues = new Array[MutableValue](values.length) + val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { - newValues(i) = values(i).copy() + newValues(i) = values(i).boxed i += 1 } - new SpecificMutableRow(newValues) + + new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { + values(ordinal).update(value) + } } - override def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) - override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala old mode 100755 new mode 100644 index 735b7488fdcb..f3830c6d3bcf --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -79,27 +79,29 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def nullable = base.nullable - override def dataType = base.dataType + override def nullable: Boolean = base.nullable + override def dataType: DataType = base.dataType def update(input: Row): Unit // Do we really need this? - override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance(): AggregateFunction = { + makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + } } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MIN($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) } - override def newInstance() = new MinFunction(child, this) + override def newInstance(): MinFunction = new MinFunction(child, this) } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MAX($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) } - override def newInstance() = new MaxFunction(child, this) + override def newInstance(): MaxFunction = new MaxFunction(child, this) } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } - override def newInstance() = new CountFunction(child, this) + override def newInstance(): CountFunction = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) - override def children = expressions + override def children: Seq[Expression] = expressions - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance() = new CountDistinctFunction(expressions, this) + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(expressions), "partialSets")() SplitEvaluation( CombineSetsAndCount(partialSet.toAttribute), @@ -185,11 +187,12 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) - override def children = expressions - override def nullable = false - override def dataType = ArrayType(expressions.head.dataType) - override def toString = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance() = new CollectHashSetFunction(expressions, this) + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) + override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance(): CollectHashSetFunction = + new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -219,11 +222,13 @@ case class CollectHashSetFunction( case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { def this() = this(null) - override def children = inputSet :: Nil - override def nullable = false - override def dataType = LongType - override def toString = s"CombineAndCount($inputSet)" - override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"CombineAndCount($inputSet)" + override def newInstance(): CombineSetsAndCountFunction = { + new CombineSetsAndCountFunction(inputSet, this) + } } case class CombineSetsAndCountFunction( @@ -246,30 +251,51 @@ case class CombineSetsAndCountFunction( override def eval(input: Row): Any = seen.size.toLong } +/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ +private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { + + override def sqlType: DataType = BinaryType + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def serialize(obj: Any): Array[Byte] = + obj.asInstanceOf[HyperLogLog].getBytes + + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def deserialize(datum: Any): HyperLogLog = + HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) + + override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] +} + case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: DataType = HyperLogLogUDT + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctPartitionFunction = { + new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + } } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctMergeFunction = { + new ApproxCountDistinctMergeFunction(child, this, relativeSD) + } } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def asPartial: SplitEvaluation = { val partialCount = @@ -280,14 +306,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) partialCount :: Nil) } - override def newInstance() = new CountDistinctFunction(child :: Nil, this) + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -296,11 +322,11 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) => + case DecimalType.Fixed(_, _) | DecimalType.Unlimited => // Turn the child to unlimited decimals for calculation, before going back to fixed val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() @@ -323,14 +349,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } } - override def newInstance() = new AverageFunction(child, this) + override def newInstance(): AverageFunction = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -339,33 +365,57 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), + Cast(CombineSum(partialSum.toAttribute), dataType), partialSum :: Nil) case _ => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - Sum(partialSum.toAttribute), + CombineSum(partialSum.toAttribute), partialSum :: Nil) } } - override def newInstance() = new SumFunction(child, this) + override def newInstance(): SumFunction = new SumFunction(child, this) +} + +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { def this() = this(null) - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -373,10 +423,10 @@ case class SumDistinct(child: Expression) case _ => child.dataType } - override def toString = s"SUM(DISTINCT ${child})" - override def newInstance() = new SumDistinctFunction(child, this) + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() SplitEvaluation( CombineSetsAndSum(partialSet.toAttribute, this), @@ -387,11 +437,13 @@ case class SumDistinct(child: Expression) case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { def this() = this(null, null) - override def children = inputSet :: Nil - override def nullable = true - override def dataType = base.dataType - override def toString = s"CombineAndSum($inputSet)" - override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } } case class CombineSetsAndSumFunction( @@ -425,9 +477,9 @@ case class CombineSetsAndSumFunction( } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"FIRST($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child), "PartialFirst")() @@ -435,14 +487,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance() = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references - override def nullable = true - override def dataType = child.dataType - override def toString = s"LAST($child)" + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child), "PartialLast")() @@ -450,7 +502,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode Last(partialLast.toAttribute), partialLast :: Nil) } - override def newInstance() = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -471,7 +523,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private var count: Long = _ private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + private def addFunction(value: Any) = Add(sum, + Cast(Literal.create(value, expr.dataType), calcType)) override def eval(input: Row): Any = { if (count == 0L) { @@ -565,7 +618,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) + private val addFunction = + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { sum.update(addFunction, input) @@ -580,6 +634,43 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +case class CombineSumFunction(expr: Expression, base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) + + private val sum = MutableLiteral(null, calcType) + + private val addFunction = + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + + override def update(input: Row): Unit = { + val result = expr.eval(input) + // partial sum result can be null only when no input rows present + if(result != null) { + sum.update(addFunction, input) + } + } + + override def eval(input: Row): Any = { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(sum, dataType).eval(null) + case _ => sum.eval(null) + } + } +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -651,6 +742,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg result = input } - override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) - else null + override def eval(input: Row): Any = { + if (result != null) expr.eval(result.asInstanceOf[Row]) else null + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 574907f566c0..140ccd8d3796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,41 +18,53 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"-$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"-$child" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } override def eval(input: Row): Any = { - n1(child, input, _.negate(_)) + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + numeric.negate(evalE) + } } } case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = DoubleType - override def foldable = child.foldable - def nullable = true - override def toString = s"SQRT($child)" + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"SQRT($child)" + + lazy val numeric = child.dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support non-negative numeric operations") + } override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - child.dataType match { - case n: NumericType => - val value = n.numeric.toDouble(evalE.asInstanceOf[n.JvmType]) - if (value < 0) null - else math.sqrt(value) - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } } } @@ -62,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + def nullable: Boolean = left.nullable || right.nullable override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType = { + def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -96,51 +108,122 @@ abstract class BinaryArithmetic extends BinaryExpression { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "+" + override def symbol: String = "+" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } - override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _)) + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.plus(evalE1, evalE2) + } + } + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "-" + override def symbol: String = "-" - override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _)) + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.minus(evalE1, evalE2) + } + } + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "*" + override def symbol: String = "*" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } - override def eval(input: Row): Any = n2(input, left, right, _.times(_, _)) + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + numeric.times(evalE1, evalE2) + } + } + } } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "/" + override def symbol: String = "/" - override def nullable = true + override def nullable: Boolean = true + lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot + case other => sys.error(s"Type $other does not support numeric operations") + } + override def eval(input: Row): Any = { val evalE2 = right.eval(input) - dataType match { - case _ if evalE2 == null => null - case _ if evalE2 == 0 => null - case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType])) - case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType])) + if (evalE2 == null || evalE2 == 0) { + null + } else { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + div(evalE1, evalE2) + } } } - } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "%" + override def symbol: String = "%" - override def nullable = true + override def nullable: Boolean = true + + lazy val integral = dataType match { + case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] + case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } override def eval(input: Row): Any = { val evalE2 = right.eval(input) - dataType match { - case _ if evalE2 == null => null - case _ if evalE2 == 0 => null - case nt: NumericType => i1(input, left, _.rem(_, evalE2.asInstanceOf[nt.JvmType])) + if (evalE2 == null || evalE2 == 0) { + null + } else { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + integral.rem(evalE1, evalE2) + } } } } @@ -149,45 +232,63 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise and(&) of two numbers. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long] + override def symbol: String = "&" + + lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise & operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) } /** * A function that calculates bitwise or(|) of two numbers. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "|" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long] + override def symbol: String = "|" + + lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise | operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) } /** * A function that calculates bitwise xor(^) of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "^" - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { - case ByteType => (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte - case ShortType => (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort - case IntegerType => evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int] - case LongType => evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long] + override def symbol: String = "^" + + lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case other => sys.error(s"Unsupported bitwise ^ operation on $other") } + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) } /** @@ -196,23 +297,29 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"~$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"~$child" + + lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + case other => sys.error(s"Unsupported bitwise ~ operation on $other") + } override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - dataType match { - case ByteType => (~evalE.asInstanceOf[Byte]).toByte - case ShortType => (~evalE.asInstanceOf[Short]).toShort - case IntegerType => ~evalE.asInstanceOf[Int] - case LongType => ~evalE.asInstanceOf[Long] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") - } + not(evalE) } } } @@ -220,32 +327,91 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable + + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 + } else { + if (ordering.compare(evalE1, evalE2) < 0) { + evalE2 + } else { + evalE1 + } + } + } + + override def toString: String = s"MaxOf($left, $right)" +} + +case class MinOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable: Boolean = left.foldable && right.foldable - override def nullable = left.nullable && right.nullable + override def nullable: Boolean = left.nullable && right.nullable - override def children = left :: right :: Nil + override def children: Seq[Expression] = left :: right :: Nil - override def dataType = left.dataType + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } override def eval(input: Row): Any = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - if (leftEval == null) { - rightEval - } else if (rightEval == null) { - leftEval + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 } else { - val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - if (numeric.compare(leftEval, rightEval) < 0) { - rightEval + if (ordering.compare(evalE1, evalE2) < 0) { + evalE1 } else { - leftEval + evalE2 } } } - override def toString = s"MaxOf($left, $right)" + override def toString: String = s"MinOf($left, $right)" } /** @@ -254,10 +420,22 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { case class Abs(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"Abs($child)" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"Abs($child)" + + lazy val numeric = dataType match { + case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] + case other => sys.error(s"Type $other does not support numeric operations") + } - override def eval(input: Row): Any = n1(child, input, _.abs(_)) + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + numeric.abs(evalE) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cae5c471868..dbc92fb93e95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -91,18 +91,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() - def timeMs = (endTime - startTime).toDouble / 1000000 + def timeMs: Double = (endTime - startTime).toDouble / 1000000 logInfo(s"Code generated expression $in in $timeMs ms") result } }) /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = - apply(bind(expressions, inputSchema)) + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** * Returns a term name that is unique within this instance of a `CodeGenerator`. @@ -121,7 +121,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param nullTerm A term that holds a boolean value representing whether the expression evaluated * to null. * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `false`. + * valid if `nullTerm` is set to `true`. * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( @@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -243,9 +244,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children + case Cast(child @ DateType(), StringType) => + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) + case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -256,7 +263,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin child.castOrNull(c => q"$c.toDouble", DoubleType) case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", IntegerType) + child.castOrNull(c => q"$c.toFloat", FloatType) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. @@ -269,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } @@ -461,7 +477,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val itemEval = expressionEvaluator(item) val setEval = expressionEvaluator(set) - val ArrayType(elementType, _) = set.dataType + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType itemEval.code ++ setEval.code ++ q""" @@ -479,7 +495,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val leftEval = expressionEvaluator(left) val rightEval = expressionEvaluator(right) - val ArrayType(elementType, _) = left.dataType + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType leftEval.code ++ rightEval.code ++ q""" @@ -521,6 +537,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + case UnscaledValue(child) => val childEval = expressionEvaluator(child) @@ -570,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -581,7 +622,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { - case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" + case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } } @@ -592,7 +634,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { - case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case StringType => q"$destinationRow.update($ordinal, $value)" + case dt: DataType if isNativeType(dt) => + q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } } @@ -615,15 +659,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(1L)) + case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" @@ -632,7 +676,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def termForType(dt: DataType) = dt match { - case n: NativeType => n.tag + case n: AtomicType => n.tag case _ => typeTag[Any] } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + protected val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index a419fd7ecb39..840260703ab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -30,7 +30,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0db29eb404bd..b129c0d898bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of @@ -30,7 +30,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = - in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) @@ -43,6 +43,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val evalB = expressionEvaluator(order.child) val compare = order.child.dataType match { + case BinaryType => + q""" + val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} + val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} + var i = 0 + while (i < x.length && i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + i = i+1 + } + return x.length - y.length + """ case _: NumericType => q""" val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 2a0935c790cf..40e163024360 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -26,7 +26,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a888..584f938445c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -31,7 +31,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) @@ -109,38 +109,56 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" } - val specificAccessorFunctions = NativeType.all.map { dataType => + val specificAccessorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } - val specificMutatorFunctions = NativeType.all.map { dataType => + val specificMutatorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 80c7dfd376c9..528e38a50a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.util +import org.apache.spark.util.Utils /** * A collection of generators that build custom bytecode at runtime for performing the evaluation @@ -52,7 +52,7 @@ package object codegen { @DeveloperApi object DumpByteCode { import scala.sys.process._ - val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + val dumpDirectory = Utils.createTempDir() dumpDirectory.mkdir() def apply(obj: Any): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 1bc34f71441f..fc1f69655963 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types._ /** @@ -27,12 +29,12 @@ import org.apache.spark.sql.types._ case class GetItem(child: Expression, ordinal: Expression) extends Expression { type EvaluatedType = Any - val children = child :: ordinal :: Nil + val children: Seq[Expression] = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. */ - override def nullable = true - override def foldable = child.foldable && ordinal.foldable + override def nullable: Boolean = true + override def foldable: Boolean = child.foldable && ordinal.foldable - def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt } @@ -40,7 +42,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { childrenResolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - override def toString = s"$child[$ordinal]" + override def toString: String = s"$child[$ordinal]" override def eval(input: Row): Any = { val value = child.eval(input) @@ -70,42 +72,83 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } -/** - * Returns the value of fields in the Struct `child`. - */ -case class GetField(child: Expression, fieldName: String) extends UnaryExpression { - type EvaluatedType = Any - def dataType = field.dataType - override def nullable = child.nullable || field.nullable - override def foldable = child.foldable +trait GetField extends UnaryExpression { + self: Product => - protected def structType = child.dataType match { - case s: StructType => s - case otherType => sys.error(s"GetField is not valid on fields of type $otherType") - } + type EvaluatedType = Any + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" - lazy val field = - structType.fields - .find(_.name == fieldName) - .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}")) + def field: StructField +} - lazy val ordinal = structType.fields.indexOf(field) +object GetField { + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + */ + def apply( + expr: Expression, + fieldName: String, + resolver: Resolver): GetField = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } +} - override lazy val resolved = childrenResolved && fieldResolved +/** + * Returns the value of fields in the Struct `child`. + */ +case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { - /** Returns true only if the fieldName is found in the child struct. */ - private def fieldResolved = child.dataType match { - case StructType(fields) => fields.map(_.name).contains(fieldName) - case _ => false - } + override def dataType: DataType = field.dataType + override def nullable: Boolean = child.nullable || field.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } +} - override def toString = s"$child.$fieldName" +/** + * Returns the array of value of fields in the Array of Struct `child`. + */ +case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) + extends GetField { + + override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable + + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + if (baseValue == null) null else { + baseValue.map { row => + if (row == null) null else row(ordinal) + } + } + } } /** @@ -114,7 +157,7 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -134,5 +177,32 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString = s"Array(${children.mkString(",")})" + override def toString: String = s"Array(${children.mkString(",")})" +} + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. + */ +case class CreateStruct(children: Seq[NamedExpression]) extends Expression { + override type EvaluatedType = Row + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + assert(resolved, + s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") + val fields = children.map { child => + StructField(child.name, child.dataType, child.nullable, child.metadata) + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: Row): EvaluatedType = { + Row(children.map(_.eval(input)): _*) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 83d8c1d42bca..adb94df7d1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"UnscaledValue($child)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"UnscaledValue($child)" override def eval(input: Row): Any = { val childResult = child.eval(input) @@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"MakeDecimal($child,$precision,$scale)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: Row): Decimal = { val childResult = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 43b6482c0171..9a6cb048af5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ /** @@ -42,64 +42,57 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] - override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + // TODO ideally we should return the type of ArrayType(StructType), + // however, we don't keep the output field names in the Generator. + override def dataType: DataType = throw new UnsupportedOperationException - override def nullable = false + override def nullable: Boolean = false /** - * Should be overridden by specific generators. Called only once for each instance to ensure - * that rule application does not change the output schema of a generator. + * The output element data types in structure of Seq[(DataType, Nullable)] + * TODO we probably need to add more information like metadata etc. */ - protected def makeOutput(): Seq[Attribute] - - private var _output: Seq[Attribute] = null - - def output: Seq[Attribute] = { - if (_output == null) { - _output = makeOutput() - } - _output - } + def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] +} + +/** + * A generator that produces its output using the provided lambda function. + */ +case class UserDefinedGenerator( + elementTypes: Seq[(DataType, Boolean)], + function: Row => TraversableOnce[Row], + children: Seq[Expression]) + extends Generator { - /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { - val copy = super.makeCopy(newArgs) - copy._output = _output - copy + override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + // Convert the objects into Scala Type before calling function, we need schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + val inputRow = new InterpretedProjection(children) + function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } + + override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" } /** * Given an input array produces a sequence of rows for each value in the array. */ -case class Explode(attributeNames: Seq[String], child: Expression) +case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { override lazy val resolved = child.resolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - private lazy val elementTypes = child.dataType match { + override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - // TODO: Move this pattern into Generator. - protected def makeOutput() = - if (attributeNames.size == elementTypes.size) { - attributeNames.zip(elementTypes).map { - case (n, (t, nullable)) => AttributeReference(n, t, nullable)() - } - } else { - elementTypes.zipWithIndex.map { - case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() - } - } - override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { case ArrayType(_, _) => @@ -111,5 +104,5 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def toString() = s"explode($child)" + override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5b389aad7a85..18cba4cc4670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ object Literal { @@ -29,15 +30,21 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) - case d: Date => Literal(d, DateType) + case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) + case _ => + throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) + } + + def create(v: Any, dataType: DataType): Literal = { + Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } } @@ -60,16 +67,18 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { - - override def foldable = true - def nullable = value == null +/** + * In order to do type checking, use Literal.create() instead of constructor + */ +case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { + override def foldable: Boolean = true + override def nullable: Boolean = value == null - override def toString = if (value != null) value.toString else "null" + override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row):Any = value + override def eval(input: Row): Any = value } // TODO: Specialize @@ -77,9 +86,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean extends LeafExpression { type EvaluatedType = Any - def update(expression: Expression, input: Row) = { + def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) } - override def eval(input: Row) = value + override def eval(input: Row): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3035d934ff9f..afcb2ce8b9cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees.LeafNode import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId = ExprId(curId.getAndIncrement()) + def newExprId: ExprId = ExprId(curId.getAndIncrement()) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } @@ -40,6 +41,24 @@ abstract class NamedExpression extends Expression { def name: String def exprId: ExprId + + /** + * Returns a dot separated fully qualified name for this attribute. Given that there can be + * multiple qualifiers, it is possible that there are other possible way to refer to this + * attribute. + */ + def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".") + + /** + * All possible qualifiers for the expression. + * + * For now, since we do not allow using original table name to qualify a column name once the + * table is aliased, this can only be: + * + * 1. Empty Seq: when an attribute doesn't have a qualifier, + * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. + * 2. Single element: either the table name or the alias name of the table. + */ def qualifiers: Seq[String] def toAttribute: Attribute @@ -61,13 +80,13 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => - override def references = AttributeSet(this) + override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute - def toAttribute = this + def toAttribute: Attribute = this def newInstance(): Attribute } @@ -75,31 +94,41 @@ abstract class Attribute extends NamedExpression { /** * Used to assign a new name to a computation. * For example the SQL expression "1 + 1 AS a" could be represented as follows: - * Alias(Add(Literal(1), Literal(1), "a")() + * Alias(Add(Literal(1), Literal(1)), "a")() + * + * Note that exprId and qualifiers are in a separate parameter list because + * we only pattern match on child and name. * * @param child the computation being performed * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ -case class Alias(child: Expression, name: String) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) +case class Alias(child: Expression, name: String)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil, + val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any + // Alias(Generator, xx) need to be transformed into Generate(generator, ...) + override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] - override def eval(input: Row) = child.eval(input) + override def eval(input: Row): Any = child.eval(input) - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable override def metadata: Metadata = { - child match { - case named: NamedExpression => named.metadata - case _ => Metadata.empty + explicitMetadata.getOrElse { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } } } - override def toAttribute = { + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { @@ -109,7 +138,16 @@ case class Alias(child: Expression, name: String) override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: explicitMetadata :: Nil + } + + override def equals(other: Any): Boolean = other match { + case a: Alias => + name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers && + explicitMetadata == a.explicitMetadata + case _ => false + } } /** @@ -133,7 +171,7 @@ case class AttributeReference( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -147,7 +185,7 @@ case class AttributeReference( h } - override def newInstance() = + override def newInstance(): AttributeReference = AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** @@ -172,7 +210,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - override def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { if (newQualifiers.toSet == qualifiers.toSet) { this } else { @@ -187,7 +225,29 @@ case class AttributeReference( override def toString: String = s"$name#${exprId.id}$typeSuffix" } +/** + * A place holder used when printing expressions without debugging information such as the + * expression id or the unresolved indicator. + */ +case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { + type EvaluatedType = Any + + override def toString: String = name + + override def withNullability(newNullability: Boolean): Attribute = + throw new UnsupportedOperationException + override def newInstance(): Attribute = throw new UnsupportedOperationException + override def withQualifiers(newQualifiers: Seq[String]): Attribute = + throw new UnsupportedOperationException + override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def exprId: ExprId = throw new UnsupportedOperationException + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = throw new UnsupportedOperationException + override def dataType: DataType = NullType +} + object VirtualColumn { - val groupingIdName = "grouping__id" - def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdName: String = "grouping__id" + def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 08b982bc671e..f9161cf34f0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - def nullable = !children.exists(!_.nullable) + override def nullable: Boolean = !children.exists(!_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) // Only resolved if all the children are of the same type. override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - override def toString = s"Coalesce(${children.mkString(",")})" + override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType = if (resolved) { + override def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -54,22 +55,45 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false override def eval(input: Row): Any = { child.eval(input) == null } - override def toString = s"IS NULL $child" + override def toString: String = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false - override def toString = s"IS NOT NULL $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false + override def toString: String = s"IS NOT NULL $child" override def eval(input: Row): Any = { child.eval(input) != null } } + +/** + * A predicate that is evaluated to be true if there are at least `n` non-null values. + */ +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { + override def nullable: Boolean = false + override def foldable: Boolean = false + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" + + private[this] val childrenArray = children.toArray + + override def eval(input: Row): Boolean = { + var numNonNulls = 0 + var i = 0 + while (i < childrenArray.length && numNonNulls < n) { + if (childrenArray(i).eval(input) != null) { + numNonNulls += 1 + } + i += 1 + } + numNonNulls >= n + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c84cc95520a1..9cb00cb2732f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} object InterpretedPredicate { - def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = - apply(BindReferences.bindReference(expression, inputSchema)) + def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + create(BindReferences.bindReference(expression, inputSchema)) - def apply(expression: Expression): (Row => Boolean) = { + def create(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } } @@ -33,7 +34,7 @@ object InterpretedPredicate { trait Predicate extends Expression { self: Product => - def dataType = BooleanType + override def dataType: DataType = BooleanType type EvaluatedType = Any } @@ -71,13 +72,13 @@ trait PredicateHelper { abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable } case class Not(child: Expression) extends UnaryExpression with Predicate { - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"NOT $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"NOT $child" override def eval(input: Row): Any = { child.eval(input) match { @@ -91,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { * Evaluates to `true` if `list` contains `value`. */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { - def children = value +: list + override def children: Seq[Expression] = value +: list - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: Row): Any = { val evaluatedValue = value.eval(input) @@ -109,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet(value: Expression, hset: Set[Any]) extends Predicate { - def children = value :: Nil + override def children: Seq[Expression] = value :: Nil - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" override def eval(input: Row): Any = { hset.contains(value.eval(input)) @@ -120,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "&&" + override def symbol: String = "&&" override def eval(input: Row): Any = { val l = left.eval(input) @@ -142,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } case class Or(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "||" + override def symbol: String = "||" override def eval(input: Row): Any = { val l = left.eval(input) @@ -168,21 +169,26 @@ abstract class BinaryComparison extends BinaryPredicate { } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "=" + override def symbol: String = "=" + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { null } else { val r = right.eval(input) - if (r == null) null else l == r + if (r == null) null + else if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=>" - override def nullable = false + override def symbol: String = "<=>" + + override def nullable: Boolean = false + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -197,33 +203,129 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<" - override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) + override def symbol: String = "<" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.lt(evalE1, evalE2) + } + } + } } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=" - override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _)) + override def symbol: String = "<=" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.lteq(evalE1, evalE2) + } + } + } } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">" - override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _)) + override def symbol: String = ">" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.gt(evalE1, evalE2) + } + } + } } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">=" - override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _)) + override def symbol: String = ">=" + + lazy val ordering: Ordering[Any] = { + if (left.dataType != right.dataType) { + throw new TreeNodeException(this, + s"Types do not match ${left.dataType} != ${right.dataType}") + } + left.dataType match { + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + ordering.gteq(evalE1, evalE2) + } + } + } } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends Expression { - def children = predicate :: trueValue :: falseValue :: Nil - override def nullable = trueValue.nullable || falseValue.nullable + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException( this, @@ -242,7 +344,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def toString = s"if ($predicate) $trueValue else $falseValue" + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } // scalastyle:off @@ -262,9 +364,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi // scalastyle:on case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any - def children = branches - def dataType = { + override def children: Seq[Expression] = branches + + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") } @@ -279,12 +382,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { @transient private[this] lazy val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable = { + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } - override lazy val resolved = { + override lazy val resolved: Boolean = { if (!childrenResolved) { false } else { @@ -315,7 +418,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { res } - override def toString = { + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8df150e2f855..5fd892c42e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.NativeType - +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -37,6 +36,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -44,8 +44,8 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq = Seq.empty - override def length = 0 + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException override def getLong(i: Int): Long = throw new UnsupportedOperationException @@ -56,7 +56,7 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this + override def copy(): Row = this } /** @@ -66,17 +66,17 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values.toSeq - override def length = values.length + override def length: Int = values.length - override def apply(i: Int) = values(i) + override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int) = values(i) == null + override def isNullAt(i: Int): Boolean = values(i) == null override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") @@ -114,10 +114,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") - values(i).asInstanceOf[String] + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 @@ -147,12 +152,42 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { result } - def copy() = this + override def equals(o: Any): Boolean = o match { + case other: Row => + if (values.length != other.length) { + return false + } + + var i = 0 + while (i < values.length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (apply(i) != other.apply(i)) { + return false + } + i += 1 + } + true + + case _ => false + } + + override def copy(): Row = this +} + +class GenericRowWithSchema(values: Array[Any], override val schema: StructType) + extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -162,15 +197,14 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy() = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } @@ -193,10 +227,11 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return if (order.direction == Ascending) 1 else -1 } else { val comparison = order.dataType match { - case n: NativeType if order.direction == Ascending => + case n: AtomicType if order.direction == Ascending => n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) - case n: NativeType if order.direction == Descending => + case n: AtomicType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison } @@ -205,3 +240,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def forSchema(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3a5bdca1f07c..4c4418227820 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -20,23 +20,48 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +/** The data type for expressions returning an OpenHashSet as the result. */ +private[sql] class OpenHashSetUDT( + val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { + + override def sqlType: DataType = ArrayType(elementType) + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def serialize(obj: Any): Seq[Any] = { + obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq + } + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def deserialize(datum: Any): OpenHashSet[Any] = { + val iterator = datum.asInstanceOf[Seq[Any]].iterator + val set = new OpenHashSet[Any] + while(iterator.hasNext) { + set.add(iterator.next()) + } + + set + } + + override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] + + private[spark] override def asNullable: OpenHashSetUDT = this +} + /** * Creates a new set of the specified type */ case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def nullable = false + override def nullable: Boolean = false - // We are currently only using these Expressions internally for aggregation. However, if we ever - // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - def dataType = ArrayType(elementType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - def eval(input: Row): Any = { + override def eval(input: Row): Any = { new OpenHashSet[Any]() } - override def toString = s"new Set($dataType)" + override def toString: String = s"new Set($dataType)" } /** @@ -46,12 +71,13 @@ case class NewSet(elementType: DataType) extends LeafExpression { case class AddItemToSet(item: Expression, set: Expression) extends Expression { type EvaluatedType = Any - def children = item :: set :: Nil + override def children: Seq[Expression] = item :: set :: Nil + + override def nullable: Boolean = set.nullable - def nullable = set.nullable + override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] - def dataType = set.dataType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -67,7 +93,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def toString = s"$set += $item" + override def toString: String = s"$set += $item" } /** @@ -77,13 +103,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable - def dataType = left.dataType + override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] - def symbol = "++=" + override def symbol: String = "++=" - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -109,16 +135,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def nullable = child.nullable + override def nullable: Boolean = child.nullable - def dataType = LongType + override def dataType: DataType = LongType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong } } - override def toString = s"$child.count()" + override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f85ee0a9bb6d..d597bf7ce756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -33,8 +30,8 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = left.nullable || right.nullable - def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable + override def dataType: DataType = BooleanType // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -60,49 +57,28 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "LIKE" + override def symbol: String = "LIKE" // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = + override def escape(v: String): String = if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" @@ -129,19 +105,40 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "RLIKE" + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.asInstanceOf[UTF8String]) + } + } +} + /** * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def toString() = s"Upper($child)" + override def toString: String = s"Upper($child)" } /** @@ -149,59 +146,59 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def toString() = s"Lower($child)" + override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryExpression => + self: BinaryPredicate => - type EvaluatedType = Any + override type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } - def symbol: String = nodeName + override def symbol: String = nodeName - override def toString() = s"$nodeName($left, $right)" + override def toString: String = s"$nodeName($left, $right)" } /** * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - override def compare(l: String, r: String) = l.contains(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.startsWith(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.endsWith(r) + extends BinaryPredicate with StringComparison { + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -212,22 +209,20 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends type EvaluatedType = Any - override def foldable = str.foldable && pos.foldable && len.foldable + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - def nullable: Boolean = str.nullable || pos.nullable || len.nullable - def dataType: DataType = { + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") } if (str.dataType == BinaryType) str.dataType else StringType } - override def children = str :: pos :: len :: Nil + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers @@ -236,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -245,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -258,16 +252,20 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case other => slice(other.toString, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } - override def toString = len match { + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" case _ => s"SUBSTR($str, $pos, $len)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 376a9f36568a..2d03fbfb0d31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: Batch("Combine Limits", FixedPoint(100), CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), @@ -50,7 +54,10 @@ object DefaultOptimizer extends Optimizer { CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, - ColumnPruning) :: Nil + PushPredicateThroughGenerate, + ColumnPruning) :: + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil } /** @@ -116,6 +123,15 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) + if (a.outputSet -- p.references).nonEmpty => + Project( + projectList, + Aggregate( + groupingExpressions, + aggregateExpressions.filter(e => p.references.contains(e)), + child)) + // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. @@ -125,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] { condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) @@ -182,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal.create(utf, StringType)) + } } } @@ -202,11 +223,12 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) - case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) - case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) - case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) + case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -218,36 +240,36 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => true } if (newChildren.length == 0) { - Literal(null, e.dataType) + Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren(0) } else { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } } @@ -267,13 +289,13 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(null), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. case In(Literal(v, _), list) if list.exists { case Literal(candidate, _) if candidate == v => true case _ => false - } => Literal(true, BooleanType) + } => Literal.create(true, BooleanType) } } } @@ -453,6 +475,30 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { } } +/** + * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference + * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. + */ +object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(condition, g: Generate) => + // Predicates that reference attributes produced by the `Generate` operator cannot + // be pushed below the operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { + conjunct => conjunct.references subsetOf g.child.outputSet + } + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) + stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + } else { + filter + } + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other @@ -606,7 +652,21 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } + +/** + * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to + * another LocalRelation. + * + * This is relatively simple as it currently handles only a single case: Project. + */ +object ConvertToLocalRelation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(projectList, LocalRelation(output, data)) => + val projection = new InterpretedProjection(projectList, output) + LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 310d127506d6..9c8c643f7d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]) = fields.collect { + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) - case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } } } @@ -141,10 +143,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + val namedGroupingExpressions: Map[Expression, NamedExpression] = + groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 619f42859cbb..7967189cacb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} @@ -47,8 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. + * + * Note that virtual columns should be excluded. Currently, we only support the grouping ID + * virtual column. */ - def missingInput: AttributeSet = references -- inputSet + def missingInput: AttributeSet = + (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. @@ -67,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionDown(e: Expression) = { + @inline def transformExpressionDown(e: Expression): Expression = { val newE = e.transformDown(rule) if (newE.fastEquals(e)) { e @@ -81,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -99,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression) = { + @inline def transformExpressionUp(e: Expression): Expression = { val newE = e.transformUp(rule) if (newE.fastEquals(e)) { e @@ -113,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other @@ -144,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString @@ -152,7 +158,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) + /** + * A prefix string used when printing the plan. + * + * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan. + */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString = statePrefix + super.simpleString + override def simpleString: String = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 613f4bb09daf..5dc0539caec2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,9 +17,24 @@ package org.apache.spark.sql.catalyst.plans +object JoinType { + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + case "inner" => Inner + case "outer" | "full" | "fullouter" => FullOuter + case "leftouter" | "left" => LeftOuter + case "rightouter" | "right" => RightOuter + case "leftsemi" => LeftSemi + } +} + sealed abstract class JoinType + case object Inner extends JoinType + case object LeftOuter extends JoinType + case object RightOuter extends JoinType + case object FullOuter extends JoinType + case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala similarity index 60% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 19769986ef58..e3e070f0ff30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,27 +17,35 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { - def apply(output: Attribute*) = - new LocalRelation(output) + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = { + new LocalRelation(StructType(output1 +: output).toAttributes) + } + + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { + val schema = StructType.fromAttributes(output) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) + } } -case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) +case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil) extends LeafNode with analysis.MultiInstanceRelation { - // TODO: Validate schema compliance. - def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData) - /** * Returns an identical copy of this relation with new exprIds for all attributes. Different * attributes are required when a relation is going to be included multiple times in the same * query. */ - override final def newInstance: this.type = { - LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type] + override final def newInstance(): this.type = { + LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } override protected def stringArgs = Iterator(output) @@ -47,4 +55,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data case _ => false } + + override lazy val statistics = + Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 65ae066e4b4b..57e640b273ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -18,38 +18,29 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.catalyst.trees -/** - * Estimates of various statistics. The default estimation logic simply lazily multiplies the - * corresponding statistic produced by the children. To override this behavior, override - * `statistics` and assign it an overriden version of `Statistics`. - * - * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the - * performance of the implementations. The reason is that estimations might get triggered in - * performance-critical processes, such as query plan planning. - * - * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it - * defaults to the product of children's `sizeInBytes`. - */ -private[sql] case class Statistics(sizeInBytes: BigInt) abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { self: Product => + /** + * Computes [[Statistics]] for this plan. The default implementation assumes the output + * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * of cartesian joins. + * + * [[LeafNode]]s must override this. + */ def statistics: Statistics = { if (children.size == 0) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } - - Statistics( - sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) + Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) } /** @@ -81,17 +72,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * differences like attribute naming and or expression id differences. Logical operators that * can do better should override this function. */ - def sameResult(plan: LogicalPlan): Boolean = { - plan.getClass == this.getClass && - plan.children.size == children.size && { - logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") - cleanArgs == plan.cleanArgs + def sameResult(plan: LogicalPlan): Boolean = { // 该方法永远返回false不会影响结果正确性,但会影响性能 + val cleanLeft = EliminateSubQueries(this) + val cleanRight = EliminateSubQueries(plan) + + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && { + logDebug( + s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") + cleanRight.cleanArgs == cleanLeft.cleanArgs // todo:为什么判断 cleanArgs? } && - (plan.children, children).zipped.forall(_ sameResult _) + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) } /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { + protected lazy val cleanArgs: Seq[Any] = { // todo: debug 这个方法 val input = children.flatMap(_.output) productIterator.map { // Children are checked using sameResult above. @@ -114,57 +109,112 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver) + def resolveChildren( + nameParts: Seq[String], + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String, resolver: Resolver): Option[NamedExpression] = - resolve(name, output, resolver) + def resolve( + nameParts: Seq[String], + resolver: Resolver, + throwErrors: Boolean = false): Option[NamedExpression] = + resolve(nameParts, output, resolver, throwErrors) + + /** + * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. + * + * This assumes `name` has multiple parts, where the 1st part is a qualifier + * (i.e. table name, alias, or subquery alias). + * See the comment above `candidates` variable in resolve() for semantics the returned data. + */ + private def resolveAsTableColumn( + nameParts: Seq[String], + resolver: Resolver, + attribute: Attribute): Option[(Attribute, List[String])] = { + assert(nameParts.length > 1) + if (attribute.qualifiers.exists(resolver(_, nameParts.head))) { + // At least one qualifier matches. See if remaining parts match. + val remainingParts = nameParts.tail + resolveAsColumn(remainingParts, resolver, attribute) + } else { + None + } + } + + /** + * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. + * + * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. + * See the comment above `candidates` variable in resolve() for semantics the returned data. + */ + private def resolveAsColumn( + nameParts: Seq[String], + resolver: Resolver, + attribute: Attribute): Option[(Attribute, List[String])] = { + if (resolver(attribute.name, nameParts.head)) { + Option((attribute.withName(nameParts.head), nameParts.tail.toList)) + } else { + None + } + } /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( - name: String, + nameParts: Seq[String], input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - val parts = name.split("\\.") - - // Collect all attributes that are output by this nodes children where either the first part - // matches the name or where the first part matches the scope and the second part matches the - // name. Return these matches along with any remaining parts, which represent dotted access to - // struct fields. - val options = input.flatMap { option => - // If the first part of the desired name matches a qualifier for this possible match, drop it. - val remainingParts = - if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) { - parts.drop(1) - } else { - parts + resolver: Resolver, + throwErrors: Boolean): Option[NamedExpression] = { + + // A sequence of possible candidate matches. + // Each candidate is a tuple. The first element is a resolved attribute, followed by a list + // of parts that are to be resolved. + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + var candidates: Seq[(Attribute, List[String])] = { + // If the name has 2 or more parts, try to resolve it as `table.column` first. + if (nameParts.length > 1) { + input.flatMap { option => + resolveAsTableColumn(nameParts, resolver, option) } - - if (resolver(option.name, remainingParts.head)) { - // Preserve the case of the user's attribute reference. - (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil } else { - Nil + Seq.empty + } + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + if (candidates.isEmpty) { + candidates = input.flatMap { candidate => + resolveAsColumn(nameParts, resolver, candidate) } } - options.distinct match { + def name = UnresolvedAttribute(nameParts).name + + candidates.distinct match { // One match, no nested fields, use it. case Seq((a, Nil)) => Some(a) // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - val aliased = - Alias( - resolveNesting(nestedFields, a, resolver), - nestedFields.last)() // Preserve the case of the user's field access. - Some(aliased) + try { + // The foldLeft adds GetFields for every remaining parts of the identifier, + // and aliases it with the last part of the identifier. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add GetField("c", GetField("b", a)), and alias + // the final expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver)) + val aliasName = nestedFields.last + Some(Alias(fieldExprs, aliasName)()) + } catch { + case a: AnalysisException if !throwErrors => None + } // No matches. case Seq() => @@ -173,33 +223,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => - throw new TreeNodeException( - this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") - } - } - - /** - * Given a list of successive nested field accesses, and a based expression, attempt to resolve - * the actual field lookups on this expression. - */ - private def resolveNesting( - nestedFields: List[String], - expression: Expression, - resolver: Resolver): Expression = { - - (nestedFields, expression.dataType) match { - case (Nil, _) => expression - case (requestedField :: rest, StructType(fields)) => - val actualField = fields.filter(f => resolver(f.name, requestedField)) - if (actualField.length == 0) { - sys.error( - s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") - } else if (actualField.length == 1) { - resolveNesting(rest, GetField(expression, actualField(0).name), resolver) - } else { - sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}") - } - case (_, dt) => sys.error(s"Can't access nested field in type $dt") + val referenceNames = ambiguousReferences.map(_._1).mkString(", ") + throw new AnalysisException( + s"Reference '$name' is ambiguous, could be: $referenceNames.") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index 4460c86ed902..ccf5291219ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression} /** * Transforms the input by forking and running the specified script. @@ -25,9 +25,19 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} * @param input the set of expression that should be passed to the script. * @param script the command that should be executed. * @param output the attributes that are produced by the script. + * @param ioschema the input and output schema applied in the execution of the script. */ case class ScriptTransformation( input: Seq[Expression], script: String, output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode + child: LogicalPlan, + ioschema: ScriptInputOutputSchema) extends UnaryNode { + override def references: AttributeSet = AttributeSet(input.flatMap(_.references)) +} + +/** + * A placeholder for implementation specific input and output properties when passing data + * to a script. For example, in Hive this would specify which SerDes to use. + */ +trait ScriptInputOutputSchema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala new file mode 100644 index 000000000000..9ac4c3a2a56c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overridden version of `Statistics`. + * + * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in + * cardinality estimation (e.g. cartesian joins). + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ +private[sql] case class Statistics(sizeInBytes: BigInt) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9628e93274a1..bbc94a7ab339 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -22,7 +22,17 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override lazy val resolved: Boolean = { + val containsAggregatesOrGenerators = projectList.exists ( _.collect { + case agg: AggregateExpression => agg + case generator: Generator => generator + }.nonEmpty + ) + + !expressions.exists(!_.resolved) && childrenResolved && !containsAggregatesOrGenerators + } } /** @@ -30,47 +40,61 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param alias when set, this string is applied to the schema of the output of the transformation - * as a qualifier. + * @param qualifier Qualifier for the attributes of generator(UDTF) + * @param generatorOutput The output schema of the Generator. + * @param child Children logical plan node */ case class Generate( generator: Generator, join: Boolean, outer: Boolean, - alias: Option[String], + qualifier: Option[String], + generatorOutput: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { - val output = alias - .map(a => generator.output.map(_.withQualifiers(a :: Nil))) - .getOrElse(generator.output) - if (join && outer) { - output.map(_.withNullability(true)) - } else { - output - } + override lazy val resolved: Boolean = { + generator.resolved && + childrenResolved && + generator.elementTypes.length == generatorOutput.length && + !generatorOutput.exists(!_.resolved) } - override def output = - if (join) child.output ++ generatorOutput else generatorOutput + // we don't want the gOutput to be taken as part of the expressions + // as that will cause exceptions like unresolved attributes etc. + override def expressions: Seq[Expression] = generator :: Nil + + def output: Seq[Attribute] = { + val qualified = qualifier.map(q => + // prepend the new qualifier to the existed one + generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers)) + ).getOrElse(generatorOutput) + + if (join) child.output ++ qualified else qualified + } } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output = left.output + override def output: Seq[Attribute] = left.output - override lazy val resolved = + override lazy val resolved: Boolean = childrenResolved && - !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } + + override def statistics: Statistics = { + val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes + Statistics(sizeInBytes = sizeInBytes) + } } case class Join( @@ -79,7 +103,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftSemi => left.output @@ -93,24 +117,34 @@ case class Join( left.output ++ right.output } } + + private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + // Joins are only resolved if they don't introduce ambiguious expression ids. + override lazy val resolved: Boolean = { + childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - def output = left.output + override def output: Seq[Attribute] = left.output } case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType + assert(overwrite || !ifNotExists) + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) } } @@ -120,14 +154,26 @@ case class CreateTableAsSelect[T]( child: LogicalPlan, allowExisting: Boolean, desc: Option[T] = None) extends UnaryNode { - override def output = Seq.empty[Attribute] - override lazy val resolved = databaseName != None && childrenResolved + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = databaseName != None && childrenResolved +} + +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * @param child The final query of this CTE. + * @param cteRelations Queries that this CTE defined, + * key is the alias of the CTE definition, + * value is the CTE definition. + */ +case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { + override def output: Seq[Attribute] = child.output } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -140,7 +186,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Aggregate( @@ -149,7 +195,7 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } /** @@ -163,7 +209,12 @@ case class Aggregate( case class Expand( projections: Seq[GroupExpression], output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode + child: LogicalPlan) extends UnaryNode { + override def statistics: Statistics = { + val sizeInBytes = child.statistics.sizeInBytes * projections.length + Statistics(sizeInBytes = sizeInBytes) + } +} trait GroupingAnalytics extends UnaryNode { self: Product => @@ -171,7 +222,7 @@ trait GroupingAnalytics extends UnaryNode { def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] - override def output = aggregations.map(_.toAttribute) + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) } /** @@ -236,7 +287,7 @@ case class Rollup( gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { val limit = limitExpr.eval(null).asInstanceOf[Int] @@ -246,23 +297,35 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Distinct(child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } -case object NoRelation extends LeafNode { - override def output = Nil +/** + * A relation with one row. This is used in "SELECT ..." without a from clause. + */ +case object OneRowRelation extends LeafNode { + override def output: Seq[Attribute] = Nil + + /** + * Computes [[Statistics]] for this plan. The default implementation assumes the output + * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * of cartesian joins. + * + * [[LeafNode]]s must override this. + */ + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 72b0c5c8e7a2..e737418d9c3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} /** * Performs a physical redistribution of the data. Used when the consumer of the query @@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} abstract class RedistributeData extends UnaryNode { self: Product => - def output = child.output + override def output: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData { -} + extends RedistributeData case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData { -} - + extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c3d7a311906..fb4217a44807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "a single partition.") // TODO: This is not really valid... - def clustering = ordering.map(_.child).toSet + def clustering: Set[Expression] = ordering.map(_.child).toSet } sealed trait Partitioning { @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. */ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -113,10 +118,12 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -124,10 +131,12 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -139,9 +148,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet @@ -152,12 +161,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case h: HashPartitioning if h == this => true case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -178,9 +189,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - override def children = ordering - override def nullable = false - override def dataType = IntegerType + override def children: Seq[SortOrder] = ordering + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = ordering.map(_.child).toSet @@ -194,12 +205,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case r: RangePartitioning if r == this => true case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index c441f0bf24d8..3f9858b0c4a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. */ - def apply(plan: TreeType): TreeType = { + def execute(plan: TreeType): TreeType = { var curPlan = plan batches.foreach { batch => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2013ae4f7bd1..97502ed3afe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,19 +18,53 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) +case class Origin( + line: Option[Int] = None, + startPosition: Option[Int] = None) + +/** + * Provides a location for TreeNodes to ask about the context of their origin. For example, which + * line of code is currently being parsed. + */ +object CurrentOrigin { + private val value = new ThreadLocal[Origin]() { + override def initialValue: Origin = Origin() + } + + def get: Origin = value.get() + def set(o: Origin): Unit = value.set(o) + + def reset(): Unit = value.set(Origin()) + + def setPosition(line: Int, start: Int): Unit = { + value.set( + value.get.copy(line = Some(line), startPosition = Some(start))) + } + + def withOrigin[A](o: Origin)(f: => A): A = { + set(o) + val ret = try f finally { reset() } + reset() + ret + } +} + abstract class TreeNode[BaseType <: TreeNode[BaseType]] { self: BaseType with Product => + val origin: Origin = CurrentOrigin.get + /** Returns a Seq of the children of this node */ def children: Seq[BaseType] /** * Faster version of equality which short-circuits when two treeNodes are the same instance. - * We don't just override Object.Equals, as doing so prevents the scala compiler from from + * We don't just override Object.equals, as doing so prevents the scala compiler from * generating case class `equals` methods */ def fastEquals(other: TreeNode[_]): Boolean = { @@ -46,6 +80,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { children.foreach(_.foreach(f)) } + /** + * Runs the given function recursively on [[children]] then on this node. + * @param f the function to be applied to each node in the tree. + */ + def foreachUp(f: BaseType => Unit): Unit = { + children.foreach(_.foreachUp(f)) + f(this) + } + /** * Returns a Seq containing the result of applying the given function to each * node in this tree in a preorder traversal. @@ -141,7 +184,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param rule the function used to transform this nodes children */ def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRule = rule.applyOrElse(this, identity[BaseType]) + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { transformChildrenDown(rule) @@ -175,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -201,9 +248,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { val afterRuleOnChildren = transformChildrenUp(rule); if (this fastEquals afterRuleOnChildren) { - rule.applyOrElse(this, identity[BaseType]) + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } } else { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } } } @@ -227,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -258,29 +310,42 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { - // Skip no-arg constructors that are just there for kryo. - val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head - if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] - } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + CurrentOrigin.withOrigin(origin) { + // Skip no-arg constructors that are just there for kryo. + if (otherCopyArgs.isEmpty) { + defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + } else { + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + } } } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } /** Returns the name of this type of TreeNode. Defaults to the class name. */ - def nodeName = getClass.getSimpleName + def nodeName: String = getClass.getSimpleName /** * The arguments that should be included in the arg string. Defaults to the `productIterator`. */ - protected def stringArgs = productIterator + protected def stringArgs: Iterator[Any] = productIterator /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { @@ -292,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { }.mkString(", ") /** String representation of this node without any children */ - def simpleString = s"$nodeName $argString" + def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. */ - def numberedTreeString = + def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** @@ -355,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] { def left: BaseType def right: BaseType - def children = Seq(left, right) + def children: Seq[BaseType] = Seq(left, right) } /** * A [[TreeNode]] with no children. */ trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children = Nil + def children: Seq[BaseType] = Nil } /** @@ -370,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] { */ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType - def children = child :: Nil + def children: Seq[BaseType] = child :: Nil } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 79a8e06d4b4d..ea6aa1850db4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -41,11 +41,11 @@ package object trees extends Logging { * A [[TreeNode]] companion for reference equality for Hash based Collection. */ class TreeNodeRef(val obj: TreeNode[_]) { - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case that: TreeNodeRef => that.obj.eq(obj) case _ => false } - override def hashCode = if (obj == null) 0 else obj.hashCode + override def hashCode: Int = if (obj == null) 0 else obj.hashCode } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index d8da45ae70c4..c86214a2aa94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,22 +19,11 @@ package org.apache.spark.sql.catalyst import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} -import org.apache.spark.util.{Utils => SparkUtils} +import org.apache.spark.util.Utils package object util { - /** - * Returns a path to a temporary file that probably does not exist. - * Note, there is always the race condition that someone created this - * file since the last time we checked. Thus, this shouldn't be used - * for anything security conscious. - */ - def getTempFilePath(prefix: String, suffix: String = ""): File = { - val tempFile = File.createTempFile(prefix, suffix) - tempFile.delete() - tempFile - } - def fileToString(file: File, encoding: String = "UTF-8") = { + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream try { @@ -56,7 +45,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = { + classLoader: ClassLoader = Utils.getSparkClassLoader): String = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -104,7 +93,7 @@ package object util { new String(out.toByteArray) } - def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString def benchmark[A](f: => A): A = { val startTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala new file mode 100644 index 000000000000..b116163facca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + + +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) +} + + +/** + * :: DeveloperApi :: + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * Please use [[DataTypes.createArrayType()]] to create a specific instance. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + * + * @group dataType + */ +@DeveloperApi +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) + + /** + * The default size of a value of the ArrayType is 100 * the default size of the element type. + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * elementType.defaultSize + + override def simpleString: String = s"array<${elementType.simpleString}>" + + private[spark] override def asNullable: ArrayType = + ArrayType(elementType.asNullable, containsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala new file mode 100644 index 000000000000..a581a9e9468e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Array[Byte]` values. + * Please use the singleton [[DataTypes.BinaryType]]. + * + * @group dataType + */ +@DeveloperApi +class BinaryType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + + private[sql] type InternalType = Array[Byte] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } + } + + /** + * The default size of a value of the BinaryType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: BinaryType = this +} + + +case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala new file mode 100644 index 000000000000..a7f228cefa57 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + *@group dataType + */ +@DeveloperApi +class BooleanType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the BooleanType is 1 byte. + */ + override def defaultSize: Int = 1 + + private[spark] override def asNullable: BooleanType = this +} + + +case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala new file mode 100644 index 000000000000..4d8685796ec7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. + * + * @group dataType + */ +@DeveloperApi +class ByteType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ByteType is 1 byte. + */ + override def defaultSize: Int = 1 + + override def simpleString: String = "tinyint" + + private[spark] override def asNullable: ByteType = this +} + +case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala new file mode 100644 index 000000000000..0992a7c311ee --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.util.parsing.combinator.RegexParsers + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + + +/** + * :: DeveloperApi :: + * The base type of all Spark SQL data types. + * + * @group dataType + */ +@DeveloperApi +abstract class DataType { + /** + * Enables matching against DataType for expressions: + * {{{ + * case Cast(child @ BinaryType(), StringType) => + * ... + * }}} + */ + private[sql] def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } + + /** + * The default size of a value of this data type, used internally for size estimation. + */ + def defaultSize: Int + + /** Name of the type used in JSON serialization. */ + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName + + /** The compact JSON representation of this data type. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this data type. */ + def prettyJson: String = pretty(render(jsonValue)) + + /** Readable string representation for the type. */ + def simpleString: String = typeName + + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def sameType(other: DataType): Boolean = + DataType.equalsIgnoreNullability(this, other) + + /** + * Returns the same data type but set all nullability fields are true + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def asNullable: DataType +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + * + * @group dataType + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + + +private[sql] object IntegralType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[IntegralType] => true + case _ => false + } +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + +private[sql] object FractionalType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[FractionalType] => true + case _ => false + } +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} + + +object DataType { + + def fromJson(json: String): DataType = parseDataType(parse(json)) + + @deprecated("Use DataType.fromJson instead", "1.2.0") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private val nonDecimalNameToType = { + Seq(NullType, DateType, TimestampType, BinaryType, + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + .map(t => t.typeName -> t).toMap + } + + /** Given the string representation of a type, return its DataType */ + private def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None + } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + } + + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("metadata", metadata: JObject), + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) + // Support reading schema when 'metadata' is missing. + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } + } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType. + */ + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + (left, right) match { + case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + equalsIgnoreNullability(leftElementType, rightElementType) + case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + equalsIgnoreNullability(leftKeyType, rightKeyType) && + equalsIgnoreNullability(leftValueType, rightValueType) + case (StructType(leftFields), StructType(rightFields)) => + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r + } + } + + /** + * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. + * + * Compatible nullability is defined as follows: + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` + * if and only if `to.containsNull` is true, or both of `from.containsNull` and + * `to.containsNull` are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` + * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` + * if and only if for all every pair of fields, `to.nullable` is true, or both + * of `fromField.nullable` and `toField.nullable` are false. + */ + private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + (tn || !fn) && + equalsIgnoreCompatibleNullability(fromKey, toKey) && + equalsIgnoreCompatibleNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala new file mode 100644 index 000000000000..a4d4d12f61c3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.language.implicitConversions +import scala.util.matching.Regex +import scala.util.parsing.combinator.syntactical.StandardTokenParsers + +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * This is a data type parser that can be used to parse string representations of data types + * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser. + */ +private[sql] trait DataTypeParser extends StandardTokenParsers { // 这个就是 token parser了,其内部先用 lexical parser + + // This is used to create a parser from a regex. We are using regexes for data type strings + // since these strings can be also used as column names or field names. + import lexical.Identifier + implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", + { case Identifier(str) if regex.unapplySeq(str).isDefined => str } + ) + + /** + * todo: (?i) 是什么意思? + * r 返回的是一个 regex,隐式转换为 一个 parser + */ + protected lazy val primitiveType: Parser[DataType] = + "(?i)string".r ^^^ StringType | + "(?i)float".r ^^^ FloatType | + "(?i)(?:int|integer)".r ^^^ IntegerType | + "(?i)tinyint".r ^^^ ByteType | + "(?i)smallint".r ^^^ ShortType | + "(?i)double".r ^^^ DoubleType | + "(?i)bigint".r ^^^ LongType | + "(?i)binary".r ^^^ BinaryType | + "(?i)boolean".r ^^^ BooleanType | + fixedDecimalType | + "(?i)decimal".r ^^^ DecimalType.Unlimited | + "(?i)date".r ^^^ DateType | + "(?i)timestamp".r ^^^ TimestampType | + varchar + + protected lazy val fixedDecimalType: Parser[DataType] = + ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val varchar: Parser[DataType] = + "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } + + protected lazy val mapType: Parser[DataType] = + "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ident ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { + case fields => new StructType(fields.toArray) + }) | + ("(?i)struct".r ~ "<>" ^^^ StructType(Nil)) + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(dataTypeString: String): DataType = synchronized { + phrase(dataType)(new lexical.Scanner(dataTypeString)) match { + case Success(result, _) => result + case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString)) + } + } + + private def failMessage(dataTypeString: String): String = { + s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " + + "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " + + "Please note that backtick itself is not supported in a field name." + } +} + +private[sql] object DataTypeParser { + lazy val dataTypeParser = new DataTypeParser { + override val lexical = new SqlLexical + } + // 改为parse 方法 + def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) +} + +/** The exception thrown from the [[DataTypeParser]]. */ +private[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala new file mode 100644 index 000000000000..03f0644bc784 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Date` values. + * Please use the singleton [[DataTypes.DateType]]. + * + * @group dataType + */ +@DeveloperApi +class DateType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DateType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Int + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the DateType is 4 bytes. + */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: DateType = this +} + + +case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala similarity index 58% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 21f478c80c94..d36a49159b87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -17,12 +17,48 @@ package org.apache.spark.sql.types +import java.sql.Date import java.text.SimpleDateFormat +import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.Cast +/** + * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + */ +object DateUtils { + private val MILLIS_PER_DAY = 86400000 + + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. + private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + override protected def initialValue: TimeZone = { + Calendar.getInstance.getTimeZone + } + } + + private def javaDateToDays(d: Date): Int = { + millisToDays(d.getTime) + } + + // we should use the exact day as Int, for example, (year, month, day) -> day + def millisToDays(millisLocal: Long): Int = { + ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + } -protected[sql] object DataTypeConversions { + private def toMillisSinceEpoch(days: Int): Long = { + val millisUtc = days.toLong * MILLIS_PER_DAY + millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + } + + def fromJavaDate(date: java.sql.Date): Int = { + javaDateToDays(date) + } + + def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { + new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) + } + + def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { @@ -51,11 +87,4 @@ protected[sql] object DataTypeConversions { ISO8601GMT.parse(s) } } - - /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type - case (d: java.math.BigDecimal, _) => Decimal(d) - case (other, _) => other - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 21cc6cea4bf5..994c5202c15d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case d: Decimal => compare(d) == 0 case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala new file mode 100644 index 000000000000..0f8cecd28f7d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression + + +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + + +/** + * :: DeveloperApi :: + * The data type representing `java.math.BigDecimal` values. + * A Decimal that might have fixed precision and scale, or unlimited values for these. + * + * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @group dataType + */ +@DeveloperApi +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + private[sql] type InternalType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + + def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } + + /** + * The default size of a value of the DecimalType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + override def simpleString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal(10,0)" + } + + private[spark] override def asNullable: DecimalType = this +} + + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala new file mode 100644 index 000000000000..66766623213c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Fractional, Numeric} +import scala.math.Numeric.DoubleAsIfIntegral +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @group dataType + */ +@DeveloperApi +class DoubleType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = DoubleAsIfIntegral + + /** + * The default size of a value of the DoubleType is 8 bytes. + */ + override def defaultSize: Int = 8 + + private[spark] override def asNullable: DoubleType = this +} + +case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala new file mode 100644 index 000000000000..1d5a2f4f6f86 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Numeric.FloatAsIfIntegral +import scala.math.{Ordering, Fractional, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @group dataType + */ +@DeveloperApi +class FloatType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = FloatAsIfIntegral + + /** + * The default size of a value of the FloatType is 4 bytes. + */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: FloatType = this +} + +case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala new file mode 100644 index 000000000000..74e464c08287 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @group dataType + */ +@DeveloperApi +class IntegerType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the IntegerType is 4 bytes. + */ + override def defaultSize: Int = 4 + + override def simpleString: String = "int" + + private[spark] override def asNullable: IntegerType = this +} + +case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala new file mode 100644 index 000000000000..390675782e5f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @group dataType + */ +@DeveloperApi +class LongType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "LongType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the LongType is 8 bytes. + */ + override def defaultSize: Int = 8 + + override def simpleString: String = "bigint" + + private[spark] override def asNullable: LongType = this +} + + +case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala new file mode 100644 index 000000000000..cfdf49307441 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + + +/** + * :: DeveloperApi :: + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * + * Please use [[DataTypes.createMapType()]] to create a specific instance. + * + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. + * + * @group dataType + */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + def this() = this(null, null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) + + /** + * The default size of a value of the MapType is + * 100 * (the default size of the key type + the default size of the value type). + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + + private[spark] override def asNullable: MapType = + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) +} + + +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala old mode 100755 new mode 100644 index e50e9761431f..6ee24ee0c191 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala new file mode 100644 index 000000000000..b64b07431fa9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @group dataType + */ +@DeveloperApi +class NullType private() extends DataType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "NullType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + override def defaultSize: Int = 1 + + private[spark] override def asNullable: NullType = this +} + +case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala new file mode 100644 index 000000000000..73e9ec780b0a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @group dataType + */ +@DeveloperApi +class ShortType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ShortType is 2 bytes. + */ + override def defaultSize: Int = 2 + + override def simpleString: String = "smallint" + + private[spark] override def asNullable: ShortType = this +} + +case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala new file mode 100644 index 000000000000..134ab0af4e0d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @group dataType + */ +@DeveloperApi +class StringType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "StringType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the StringType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: StringType = this +} + +case object StringType extends StringType + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala new file mode 100644 index 000000000000..83570a5eaee6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g, in selection. + */ +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } + + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala new file mode 100644 index 000000000000..d80ffca18ec9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable.ArrayBuffer +import scala.math.max + +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} + + +/** + * :: DeveloperApi :: + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Any names without matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ +@DeveloperApi +case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + /** Returns all field names in an array. */ + def fieldNames: Array[String] = fields.map(_.name) + + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the + * original order of fields. Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (nonExistFields.nonEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + protected[sql] def toAttributes: Seq[AttributeReference] = + map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> map(_.jsonValue)) + + override def apply(fieldIndex: Int): StructField = fields(fieldIndex) + + override def length: Int = fields.length + + override def iterator: Iterator[StructField] = fields.iterator + + /** + * The default size of a value of the StructType is the total default sizes of all field types. + */ + override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + + override def simpleString: String = { + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + s"struct<${fieldTypes.mkString(",")}>" + } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. + */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] + + private[spark] override def asNullable: StructType = { + val newFields = fields.map { + case StructField(name, dataType, nullable, metadata) => + StructField(name, dataType.asNullable, nullable = true, metadata) + } + + StructType(newFields) + } +} + + +object StructType { + + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + + def apply(fields: java.util.List[StructField]): StructType = { + StructType(fields.toArray.asInstanceOf[Array[StructField]]) + } + + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala new file mode 100644 index 000000000000..aebabfc47592 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.sql.Timestamp + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Timestamp` values. + * Please use the singleton [[DataTypes.TimestampType]]. + * + * @group dataType + */ +@DeveloperApi +class TimestampType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Timestamp + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + /** + * The default size of a value of the TimestampType is 12 bytes. + */ + override def defaultSize: Int = 12 + + private[spark] override def asNullable: TimestampType = this +} + +case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala new file mode 100644 index 000000000000..fc02ba6c9c43 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -0,0 +1,214 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import java.util.Arrays + +/** + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. + */ + +final class UTF8String extends Ordered[UTF8String] with Serializable { + + private[this] var bytes: Array[Byte] = _ + + /** + * Update the UTF8String with String. + */ + def set(str: String): UTF8String = { + bytes = str.getBytes("utf-8") + this + } + + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ + def set(bytes: Array[Byte]): UTF8String = { + this.bytes = bytes + this + } + + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 + } + + /** + * Return the number of code points in it. + * + * This is only used by Substring() when `start` is negative. + */ + def length(): Int = { + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + i += numOfBytes(bytes(i)) + len += 1 + } + len + } + + def getBytes: Array[Byte] = { + bytes + } + + /** + * Return a substring of this, + * @param start the position of first code point + * @param until the position after last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (until <= start || start >= bytes.length || bytes == null) { + new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + i += numOfBytes(bytes(i)) + c += 1 + } + var j = i + while (c < until && j < bytes.length) { + j += numOfBytes(bytes(j)) + c += 1 + } + UTF8String(Arrays.copyOfRange(bytes, i, j)) + } + + def contains(sub: UTF8String): Boolean = { + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In worst case, it's O(N*K), but should works fine with SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false + } + + def startsWith(prefix: UTF8String): Boolean = { + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) + } + + def endsWith(suffix: UTF8String): Boolean = { + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) + } + + def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. + UTF8String(toString().toUpperCase) + } + + def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. + UTF8String(toString().toLowerCase) + } + + override def toString(): String = { + new String(bytes, "utf-8") + } + + override def clone(): UTF8String = new UTF8String().set(this.bytes) + + override def compare(other: UTF8String): Int = { + var i: Int = 0 + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) + if (res != 0) return res + i += 1 + } + bytes.length - b.length + } + + override def compareTo(other: UTF8String): Int = { + compare(other) + } + + override def equals(other: Any): Boolean = other match { + case s: UTF8String => + Arrays.equals(bytes, s.getBytes) + case s: String => + // This is only used for Catalyst unit tests + // fail fast + bytes.length >= s.length && length() == s.length && toString() == s + case _ => + false + } + + override def hashCode(): Int = { + Arrays.hashCode(bytes) + } +} + +object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) + + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala new file mode 100644 index 000000000000..6b20505c6009 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a `DataFrame` which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. + * The conversion via `deserialize` occurs when reading from a `DataFrame`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] + + /** + * The default size of a value of the UserDefinedType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + /** + * For UDT, asNullable will not change the nullability of its internal sqlType and just returns + * itself. + */ + private[spark] override def asNullable: UserDefinedType[UserType] = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala deleted file mode 100644 index 9f30f40a173e..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ /dev/null @@ -1,977 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types - -import java.sql.{Date, Timestamp} - -import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} -import scala.util.parsing.combinator.RegexParsers - -import org.json4s._ -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.util.Utils - - -object DataType { - def fromJson(json: String): DataType = parseDataType(parse(json)) - - private object JSortedObject { - def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { - case JObject(seq) => Some(seq.toList.sortBy(_._1)) - case _ => None - } - } - - // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private def parseDataType(json: JValue): DataType = json match { - case JString(name) => - PrimitiveType.nameToType(name) - - case JSortedObject( - ("containsNull", JBool(n)), - ("elementType", t: JValue), - ("type", JString("array"))) => - ArrayType(parseDataType(t), n) - - case JSortedObject( - ("keyType", k: JValue), - ("type", JString("map")), - ("valueContainsNull", JBool(n)), - ("valueType", v: JValue)) => - MapType(parseDataType(k), parseDataType(v), n) - - case JSortedObject( - ("fields", JArray(fields)), - ("type", JString("struct"))) => - StructType(fields.map(parseStructField)) - - case JSortedObject( - ("class", JString(udtClass)), - ("pyClass", _), - ("sqlType", _), - ("type", JString("udt"))) => - Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] - } - - private def parseStructField(json: JValue): StructField = json match { - case JSortedObject( - ("metadata", metadata: JObject), - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) - // Support reading schema when 'metadata' is missing. - case JSortedObject( - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) - } - - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - - } - - protected[types] def buildFormattedString( - dataType: DataType, - prefix: String, - builder: StringBuilder): Unit = { - dataType match { - case array: ArrayType => - array.buildFormattedString(prefix, builder) - case struct: StructType => - struct.buildFormattedString(prefix, builder) - case map: MapType => - map.buildFormattedString(prefix, builder) - case _ => - } - } - - /** - * Compares two types, ignoring nullability of ArrayType, MapType, StructType. - */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { - (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => - equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => - equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => - leftFields.size == rightFields.size && - leftFields.zip(rightFields) - .forall{ - case (left, right) => - left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType) - } - case (left, right) => left == right - } - } -} - - -/** - * :: DeveloperApi :: - * - * The base type of all Spark SQL data types. - * - * @group dataType - */ -@DeveloperApi -abstract class DataType { - /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType == this => true - case _ => false - } - - /** The default size of a value of this data type. */ - def defaultSize: Int - - def isPrimitive: Boolean = false - - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase - - private[sql] def jsonValue: JValue = typeName - - def json: String = compact(render(jsonValue)) - - def prettyJson: String = pretty(render(jsonValue)) -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType - */ -@DeveloperApi -case object NullType extends DataType { - override def defaultSize: Int = 1 -} - - -protected[sql] object NativeType { - val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - - def unapply(dt: DataType): Boolean = all.contains(dt) -} - - -protected[sql] trait PrimitiveType extends DataType { - override def isPrimitive = true -} - - -protected[sql] object PrimitiveType { - private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all - private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap - - /** Given the string representation of a type, return its DataType */ - private[sql] def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r - name match { - case "decimal" => DecimalType.Unlimited - case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) - } - } -} - -protected[sql] abstract class NativeType extends DataType { - private[sql] type JvmType - @transient private[sql] val tag: TypeTag[JvmType] - private[sql] val ordering: Ordering[JvmType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. - * - * @group dataType - */ -@DeveloperApi -case object StringType extends NativeType with PrimitiveType { - private[sql] type JvmType = String - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the StringType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Array[Byte]` values. - * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType - */ -@DeveloperApi -case object BinaryType extends NativeType with PrimitiveType { - private[sql] type JvmType = Array[Byte] - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length - } - } - - /** - * The default size of a value of the BinaryType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. - * - *@group dataType - */ -@DeveloperApi -case object BooleanType extends NativeType with PrimitiveType { - private[sql] type JvmType = Boolean - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the BooleanType is 1 byte. - */ - override def defaultSize: Int = 1 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `java.sql.Timestamp` values. - * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType - */ -@DeveloperApi -case object TimestampType extends NativeType { - private[sql] type JvmType = Timestamp - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - - private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) - } - - /** - * The default size of a value of the TimestampType is 8 bytes. - */ - override def defaultSize: Int = 8 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `java.sql.Date` values. - * Please use the singleton [[DataTypes.DateType]]. - * - * @group dataType - */ -@DeveloperApi -case object DateType extends NativeType { - private[sql] type JvmType = Date - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - - private[sql] val ordering = new Ordering[JvmType] { - def compare(x: Date, y: Date) = x.compareTo(y) - } - - /** - * The default size of a value of the DateType is 8 bytes. - */ - override def defaultSize: Int = 8 -} - - -protected[sql] abstract class NumericType extends NativeType with PrimitiveType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[JvmType] -} - - -protected[sql] object NumericType { - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -/** Matcher for any expressions that evaluate to [[IntegralType]]s */ -protected[sql] object IntegralType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[IntegralType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[JvmType] -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. - * - * @group dataType - */ -@DeveloperApi -case object LongType extends IntegralType { - private[sql] type JvmType = Long - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Long]] - private[sql] val integral = implicitly[Integral[Long]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the LongType is 8 bytes. - */ - override def defaultSize: Int = 8 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType - */ -@DeveloperApi -case object IntegerType extends IntegralType { - private[sql] type JvmType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Int]] - private[sql] val integral = implicitly[Integral[Int]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the IntegerType is 4 bytes. - */ - override def defaultSize: Int = 4 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType - */ -@DeveloperApi -case object ShortType extends IntegralType { - private[sql] type JvmType = Short - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Short]] - private[sql] val integral = implicitly[Integral[Short]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the ShortType is 2 bytes. - */ - override def defaultSize: Int = 2 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType - */ -@DeveloperApi -case object ByteType extends IntegralType { - private[sql] type JvmType = Byte - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Byte]] - private[sql] val integral = implicitly[Integral[Byte]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - - /** - * The default size of a value of the ByteType is 1 byte. - */ - override def defaultSize: Int = 1 -} - - -/** Matcher for any expressions that evaluate to [[FractionalType]]s */ -protected[sql] object FractionalType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[FractionalType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[JvmType] - private[sql] val asIntegral: Integral[JvmType] -} - - -/** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - - -/** - * :: DeveloperApi :: - * - * The data type representing `java.math.BigDecimal` values. - * A Decimal that might have fixed precision and scale, or unlimited values for these. - * - * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType - */ -@DeveloperApi -case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { - private[sql] type JvmType = Decimal - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = Decimal.DecimalIsFractional - private[sql] val fractional = Decimal.DecimalIsFractional - private[sql] val ordering = Decimal.DecimalIsFractional - private[sql] val asIntegral = Decimal.DecimalAsIfIntegral - - def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) - - def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) - - override def typeName: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal" - } - - override def toString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" - case None => "DecimalType()" - } - - /** - * The default size of a value of the DecimalType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} - - -/** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { - val Unlimited: DecimalType = DecimalType(None) - - object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) - } - - object Expression { - def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) - case _ => None - } - } - - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] - - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType - */ -@DeveloperApi -case object DoubleType extends FractionalType { - private[sql] type JvmType = Double - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Double]] - private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - private[sql] val asIntegral = DoubleAsIfIntegral - - /** - * The default size of a value of the DoubleType is 8 bytes. - */ - override def defaultSize: Int = 8 -} - - -/** - * :: DeveloperApi :: - * - * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. - * - * @group dataType - */ -@DeveloperApi -case object FloatType extends FractionalType { - private[sql] type JvmType = Float - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[Float]] - private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - private[sql] val asIntegral = FloatAsIfIntegral - - /** - * The default size of a value of the FloatType is 4 bytes. - */ - override def defaultSize: Int = 4 -} - - -object ArrayType { - /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ - def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) -} - - -/** - * :: DeveloperApi :: - * - * The data type for collections of multiple values. - * Internally these are represented as columns that contain a ``scala.collection.Seq``. - * - * Please use [[DataTypes.createArrayType()]] to create a specific instance. - * - * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` values. - * - * @param elementType The data type of values. - * @param containsNull Indicates if values have `null` values - * - * @group dataType - */ -@DeveloperApi -case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append( - s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") - DataType.buildFormattedString(elementType, s"$prefix |", builder) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("elementType" -> elementType.jsonValue) ~ - ("containsNull" -> containsNull) - - /** - * The default size of a value of the ArrayType is 100 * the default size of the element type. - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * elementType.defaultSize -} - - -/** - * A field inside a StructType. - * - * @param name The name of this field. - * @param dataType The data type of this field. - * @param nullable Indicates if values of this field can be `null` values. - * @param metadata The metadata of this field. The metadata should be preserved during - * transformation if the content of the column is not modified, e.g, in selection. - */ -case class StructField( - name: String, - dataType: DataType, - nullable: Boolean = true, - metadata: Metadata = Metadata.empty) { - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") - DataType.buildFormattedString(dataType, s"$prefix |", builder) - } - - // override the default toString to be compatible with legacy parquet files. - override def toString: String = s"StructField($name,$dataType,$nullable)" - - private[sql] def jsonValue: JValue = { - ("name" -> name) ~ - ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) - } -} - - -object StructType { - protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) - - def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) - - def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) - } -} - - -/** - * :: DeveloperApi :: - * - * A [[StructType]] object can be constructed by - * {{{ - * StructType(fields: Seq[StructField]) - * }}} - * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. - * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. - * If a provided name does not have a matching field, it will be ignored. For the case - * of extracting a single StructField, a `null` will be returned. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val struct = - * StructType( - * StructField("a", IntegerType, true) :: - * StructField("b", LongType, false) :: - * StructField("c", BooleanType, false) :: Nil) - * - * // Extract a single StructField. - * val singleField = struct("b") - * // singleField: StructField = StructField(b,LongType,false) - * - * // This struct does not have a field called "d". null will be returned. - * val nonExisting = struct("d") - * // nonExisting: StructField = null - * - * // Extract multiple StructFields. Field names are provided in a set. - * // A StructType object will be returned. - * val twoFields = struct(Set("b", "c")) - * // twoFields: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * - * // Any names without matching fields will be ignored. - * // For the case shown below, "d" will be ignored and - * // it is treated as struct(Set("b", "c")). - * val ignoreNonExisting = struct(Set("b", "c", "d")) - * // ignoreNonExisting: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * }}} - * - * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val innerStruct = - * StructType( - * StructField("f1", IntegerType, true) :: - * StructField("f2", LongType, false) :: - * StructField("f3", BooleanType, false) :: Nil) - * - * val struct = StructType( - * StructField("a", innerStruct, true) :: Nil) - * - * // Create a Row with the schema defined by struct - * val row = Row(Row(1, 2, true)) - * // row: Row = [[1,2,true]] - * }}} - * - * @group dataType - */ -@DeveloperApi -case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { - - /** Returns all field names in an array. */ - def fieldNames: Array[String] = fields.map(_.name) - - private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap - - /** - * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not - * have a name matching the given name, `null` will be returned. - */ - def apply(name: String): StructField = { - nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) - } - - /** - * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the - * original order of fields. Those names which do not have matching fields will be ignored. - */ - def apply(names: Set[String]): StructType = { - val nonExistFields = names -- fieldNamesSet - if (nonExistFields.nonEmpty) { - throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") - } - // Preserve the original order of fields. - StructType(fields.filter(f => names.contains(f.name))) - } - - protected[sql] def toAttributes: Seq[AttributeReference] = - map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - - def treeString: String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - fields.foreach(field => field.buildFormattedString(prefix, builder)) - - builder.toString() - } - - def printTreeString(): Unit = println(treeString) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - fields.foreach(field => field.buildFormattedString(prefix, builder)) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("fields" -> map(_.jsonValue)) - - override def apply(fieldIndex: Int): StructField = fields(fieldIndex) - - override def length: Int = fields.length - - override def iterator: Iterator[StructField] = fields.iterator - - /** - * The default size of a value of the StructType is the total default sizes of all field types. - */ - override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum -} - - -object MapType { - /** - * Construct a [[MapType]] object with the given key type and value type. - * The `valueContainsNull` is true. - */ - def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) -} - - -/** - * :: DeveloperApi :: - * - * The data type for Maps. Keys in a map are not allowed to have `null` values. - * - * Please use [[DataTypes.createMapType()]] to create a specific instance. - * - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType - */ -case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- key: ${keyType.typeName}\n") - builder.append(s"$prefix-- value: ${valueType.typeName} " + - s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) - DataType.buildFormattedString(valueType, s"$prefix |", builder) - } - - override private[sql] def jsonValue: JValue = - ("type" -> typeName) ~ - ("keyType" -> keyType.jsonValue) ~ - ("valueType" -> valueType.jsonValue) ~ - ("valueContainsNull" -> valueContainsNull) - - /** - * The default size of a value of the MapType is - * 100 * (the default size of the key type + the default size of the value type). - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) -} - - -/** - * ::DeveloperApi:: - * The data type for User Defined Types (UDTs). - * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a SchemaRDD which has class X in the schema. - * - * For SparkSQL to recognize UDTs, the UDT must be annotated with - * [[SQLUserDefinedType]]. - * - * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. - * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. - */ -@DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { - - /** Underlying storage type for this UDT */ - def sqlType: DataType - - /** Paired Python UDT class, if exists. */ - def pyUDT: String = null - - /** - * Convert the user type to a SQL datum - * - * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - * where we need to convert Any to UserType. - */ - def serialize(obj: Any): Any - - /** Convert a SQL datum to the user type */ - def deserialize(datum: Any): UserType - - override private[sql] def jsonValue: JValue = { - ("type" -> "udt") ~ - ("class" -> this.getClass.getName) ~ - ("pyClass" -> pyUDT) ~ - ("sqlType" -> sqlType.jsonValue) - } - - /** - * Class object for the UserType - */ - def userClass: java.lang.Class[UserType] - - /** - * The default size of a value of the UserDefinedType is 4096 bytes. - */ - override def defaultSize: Int = 4096 -} diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index 287c8e356350..eb3b1999eb99 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala new file mode 100644 index 000000000000..bbb9739e9cc7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} +import org.apache.spark.sql.types._ +import org.scalatest.{Matchers, FunSpec} + +class RowTest extends FunSpec with Matchers { + + val schema = StructType( + StructField("col1", StringType) :: + StructField("col2", StringType) :: + StructField("col3", IntegerType) :: Nil) + val values = Array("value1", "value2", 1) + + val sampleRow: Row = new GenericRowWithSchema(values, schema) + val noSchemaRow: Row = new GenericRow(values) + + describe("Row (without schema)") { + it("throws an exception when accessing by fieldName") { + intercept[UnsupportedOperationException] { + noSchemaRow.fieldIndex("col1") + } + intercept[UnsupportedOperationException] { + noSchemaRow.getAs("col1") + } + } + } + + describe("Row (with schema)") { + it("fieldIndex(name) returns field index") { + sampleRow.fieldIndex("col1") shouldBe 0 + sampleRow.fieldIndex("col3") shouldBe 2 + } + + it("getAs[T] retrieves a value by fieldname") { + sampleRow.getAs[String]("col1") shouldBe "value1" + sampleRow.getAs[Int]("col3") shouldBe 1 + } + + it("Accessing non existent field throws an exception") { + intercept[IllegalArgumentException] { + sampleRow.getAs[String]("non_existent") + } + } + + it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") { + val expected = Map( + "col1" -> "value1", + "col2" -> "value2" + ) + sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 46b2250aab23..ea82cd2622de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -30,7 +30,7 @@ class DistributionSuite extends FunSuite { inputPartitioning: Partitioning, requiredDistribution: Distribution, satisfied: Boolean) { - if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) { fail( s""" |== Input Partitioning == @@ -40,6 +40,7 @@ class DistributionSuite extends FunSuite { |== Does input partitioning satisfy required distribution? == |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} """.stripMargin) + } } test("HashPartitioning is the output partitioning") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5138942a55da..bbc0b661a0c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -60,10 +60,13 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], + arrayField1: Array[Int], + arrayField2: List[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], - structField: PrimitiveData) + structField: PrimitiveData, + nestedArrayField: Array[Array[Int]]) case class GenericData[A]( genericField: A) @@ -131,6 +134,14 @@ class ScalaReflectionSuite extends FunSuite { "arrayField", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField1", + ArrayType(IntegerType, containsNull = false), + nullable = true), + StructField( + "arrayField2", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), @@ -153,7 +164,10 @@ class ScalaReflectionSuite extends FunSuite { StructField("shortField", ShortType, nullable = false), StructField("byteField", ByteType, nullable = false), StructField("booleanField", BooleanType, nullable = false))), - nullable = true))), + nullable = true), + StructField( + "nestedArrayField", + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)))), nullable = true)) } @@ -246,7 +260,7 @@ class ScalaReflectionSuite extends FunSuite { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { @@ -256,7 +270,7 @@ class ScalaReflectionSuite extends FunSuite { val dataType = schemaFor[OptionalData].dataType val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, Row(1, 1, 1, 1, 1, 1, true)) - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 1a0a0e6154ad..a652c7056099 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -49,13 +49,14 @@ class SqlParserSuite extends FunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + assert(TestCommand("NotRealCommand") === + parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) } test("test case insensitive") { val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 3aea337460d4..971e1ff5ec2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -30,10 +30,22 @@ import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) - val caseSensitiveAnalyze = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) - val caseInsensitiveAnalyze = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + + val caseSensitiveAnalyzer = + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + val caseInsensitiveAnalyzer = + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + + + def caseSensitiveAnalyze(plan: LogicalPlan): Unit = + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) + + def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -43,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -51,37 +78,50 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("union project *") { val plan = (1 to 100) .map(_ => testRelation) - .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None)))) + .fold[LogicalPlan](testRelation) { (a, b) => + a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) + } + + assert(caseInsensitiveAnalyzer.execute(plan).resolved) + } + + test("check project's resolved") { + assert(Project(testRelation.output, testRelation).resolved) + + assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved) + + val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) + assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(caseInsensitiveAnalyze(plan).resolved) + assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { assert( - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyze( + caseSensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) - val e = intercept[TreeNodeException[_]] { + val e = intercept[AnalysisException] { caseSensitiveAnalyze( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) } - assert(e.getMessage().toLowerCase.contains("unresolved")) + assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -94,36 +134,83 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) === - testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } - test("throw errors for unresolved attributes during analysis") { - val e = intercept[TreeNodeException[_]] { - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true): Unit = { + test(name) { + val error = intercept[AnalysisException] { + if(caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage contains m)) } - assert(e.getMessage().toLowerCase.contains("unresolved attribute")) } - test("throw errors for unresolved plans during analysis") { - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output = Nil - } - val e = intercept[TreeNodeException[_]] { - caseSensitiveAnalyze(UnresolvedTestPlan()) - } - assert(e.getMessage().toLowerCase.contains("unresolved plan")) + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + errorTest( + "ambiguous field", + nestedRelation.select($"top.duplicateField"), + "Ambiguous reference to fields" :: "duplicateField" :: Nil, + caseSensitive = false) + + errorTest( + "ambiguous field due to case insensitivity", + nestedRelation.select($"top.differentCase"), + "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil, + caseSensitive = false) + + errorTest( + "missing field", + nestedRelation2.select($"top.c"), + "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil, + caseSensitive = false) + + case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output: Seq[Attribute] = Nil } + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( AttributeReference("a", StringType)(), @@ -132,22 +219,37 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val expr0 = 'a / 2 - val expr1 = 'a / 'b - val expr2 = 'a / 'c - val expr3 = 'a / 'd - val expr4 = 'e / 'e - val plan = caseInsensitiveAnalyze(Project( - Alias(expr0, s"Analyzer($expr0)")() :: - Alias(expr1, s"Analyzer($expr1)")() :: - Alias(expr2, s"Analyzer($expr2)")() :: - Alias(expr3, s"Analyzer($expr3)")() :: - Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val plan = caseInsensitiveAnalyzer.execute( + testRelation2.select( + 'a / Literal(2) as 'div1, + 'a / 'b as 'div2, + 'a / 'c as 'div3, + 'a / 'd as 'div4, + 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", StringType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d586..36b03d1c65e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -46,18 +48,29 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { private def checkType(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Seq(Alias(expression, "c")()), relation) - assert(analyzer(plan).schema.fields(0).dataType === expectedType) + assert(analyzer.execute(plan).schema.fields(0).dataType === expectedType) } private def checkComparison(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Alias(expression, "c")() :: Nil, relation) - val comparison = analyzer(plan).collect { + val comparison = analyzer.execute(plan).collect { case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e }.head assert(comparison.left.dataType === expectedType) assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer.execute(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f5a502b43f80..fcd745f43cfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ -class HiveTypeCoercionSuite extends FunSuite { +class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { @@ -96,7 +96,9 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(StringType, TimestampType, None) // ComplexType - widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) @@ -106,12 +108,43 @@ class HiveTypeCoercionSuite extends FunSuite { val booleanCasts = new HiveTypeCoercion { }.BooleanCasts def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == + comparePlans( + booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) // Stringify boolean when casting to string. - ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) + ruleTest( + Cast(Literal(false), StringType), + If(Literal(false), Literal("true"), Literal("false"))) + } + + test("coalesce casts") { + val fac = new HiveTypeCoercion { }.FunctionArgumentConversion + def ruleTest(initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + fac(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + ruleTest( + Coalesce(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + Coalesce(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest( + Coalesce(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + Coalesce(Cast(Literal(1L), DecimalType()) + :: Cast(Literal(1), DecimalType()) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + :: Nil)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 000000000000..f2f3a84d1938 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 37e64adeea85..76298f03c94a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,11 +25,44 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -class ExpressionEvaluationSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends FunSuite { + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.eval(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} + +class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { + + def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } test("literals") { checkEvaluation(Literal(1), 1) @@ -54,10 +87,13 @@ class ExpressionEvaluationSuite extends FunSuite { assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) } + // scalastyle:off /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. * p q p OR q p AND q p = q * True True True True True * True False True False False @@ -74,7 +110,7 @@ class ExpressionEvaluationSuite extends FunSuite { * False True * Unknown Unknown */ - + // scalastyle:on val notTrueTable = (true, false) :: (false, true) :: @@ -83,7 +119,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - checkEvaluation(!Literal(v, BooleanType), answer) + checkEvaluation(!Literal.create(v, BooleanType), answer) } } @@ -127,38 +163,19 @@ class ExpressionEvaluationSuite extends FunSuite { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } } } - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } - test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation( + In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), + true) } test("Divide") { @@ -168,12 +185,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("Remainder") { @@ -183,12 +201,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Remainder(Literal(1), Literal(0)), null) checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("INSET") { @@ -215,14 +234,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(MaxOf(1L, 2L), 2L) checkEvaluation(MaxOf(2L, 1L), 2L) - checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + } + + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) } test("LIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType).like("a"), null) - checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) - checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -241,29 +270,29 @@ class ExpressionEvaluationSuite extends FunSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal(null, StringType), null) - checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) @@ -288,14 +317,14 @@ class ExpressionEvaluationSuite extends FunSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -303,6 +332,7 @@ class ExpressionEvaluationSuite extends FunSuite { val sd = "1970-01-01" val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" val sts = sd + " 00:00:02" val nts = sts + ".1" val ts = Timestamp.valueOf(nts) @@ -319,14 +349,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) // all convert to string type to check checkEvaluation( Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -373,12 +403,12 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) - checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) + checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) } test("date") { - val d1 = Date.valueOf("1970-01-01") - val d2 = Date.valueOf("1970-01-02") + val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) + val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) checkEvaluation(Literal(d1) < Literal(d2), true) } @@ -459,22 +489,21 @@ class ExpressionEvaluationSuite extends FunSuite { test("date casting") { val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(d, ShortType), null) - checkEvaluation(Cast(d, IntegerType), null) - checkEvaluation(Cast(d, LongType), null) - checkEvaluation(Cast(d, FloatType), null) - checkEvaluation(Cast(d, DoubleType), null) - checkEvaluation(Cast(d, DecimalType.Unlimited), null) - checkEvaluation(Cast(d, DecimalType(10, 2)), null) - checkEvaluation(Cast(d, StringType), "1970-01-01") - checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + checkEvaluation(Cast(Literal(d), ShortType), null) + checkEvaluation(Cast(Literal(d), IntegerType), null) + checkEvaluation(Cast(Literal(d), LongType), null) + checkEvaluation(Cast(Literal(d), FloatType), null) + checkEvaluation(Cast(Literal(d), DoubleType), null) + checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), null) + checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) + checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") + checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") } test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) - val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part val tss = new Timestamp(seconds) checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) @@ -500,8 +529,10 @@ class ExpressionEvaluationSuite extends FunSuite { } test("array casting") { - val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -549,10 +580,10 @@ class ExpressionEvaluationSuite extends FunSuite { } test("map casting") { - val map = Literal( + val map = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal( + val map_notNull = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) @@ -610,14 +641,14 @@ class ExpressionEvaluationSuite extends FunSuite { } test("struct casting") { - val struct = Literal( + val struct = Literal.create( Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal( + val struct_notNull = Literal.create( Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), @@ -705,7 +736,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex casting") { - val complex = Literal( + val complex = Literal.create( Row( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), @@ -736,7 +767,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -748,34 +779,35 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c2.isNull, true, row) checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(Literal(1, ShortType).isNull, false) - checkEvaluation(Literal(1, ShortType).isNotNull, true) + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - checkEvaluation(Literal(null, ShortType).isNull, true) - checkEvaluation(Literal(null, ShortType).isNotNull, false) + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), - Literal("a", StringType), Literal("b", StringType)), "b", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) checkEvaluation(c1 in (c1, c2), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -786,9 +818,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) @@ -818,13 +850,13 @@ class ExpressionEvaluationSuite extends FunSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - )) + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil @@ -834,52 +866,68 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), - Literal(null, StringType)), null, row) + Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), - Literal(null, IntegerType)), null, row) + Literal.create(null, IntegerType)), null, row) + + def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = { + expr.dataType match { + case StructType(fields) => + val field = fields.find(_.name == fieldName).get + StructGetField(expr, field, fields.indexOf(field)) + } + } + + def quickResolve(u: UnresolvedGetField): StructGetField = { + quickBuildGetField(u.child, u.fieldName) + } - checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) + checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) + checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) - assert(GetField(Literal(null, typeS), "a").nullable === true) - assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) - checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) + checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) val c4 = 'a.int.at(3) checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -890,19 +938,20 @@ class ExpressionEvaluationSuite extends FunSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) val c4 = 'a.double.at(3) checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) - checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -913,7 +962,7 @@ class ExpressionEvaluationSuite extends FunSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -923,9 +972,10 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -937,85 +987,115 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 <=> c4, false, row) checkEvaluation(c4 <=> c6, true, row) checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) - checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal(null, StringType), null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) checkEvaluation(c1 startsWith "a", true, row) checkEvaluation(c1 startsWith "b", false, row) checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) checkEvaluation(c1 endsWith "c", true, row) checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, create_row(null)) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1026,20 +1106,20 @@ class ExpressionEvaluationSuite extends FunSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -1047,22 +1127,25 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 & c2, 0, row) checkEvaluation(c1 | c2, 3, row) @@ -1070,3 +1153,14 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(~c1, -2, row) } } + +// TODO: Make the tests work with codegen. +class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + + test("CreateStruct") { + val row = Row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index ef3114fd4dba..b5ebe4b38e33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -29,7 +29,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = try { - GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => val evaluated = GenerateProjection.expressionEvaluator(expression) @@ -56,10 +56,10 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { val futures = (1 to 20).map { _ => future { - GeneratePredicate(EqualTo(Literal(1), Literal(1))) - GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) + GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebc..97af2e0fd050 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -25,13 +25,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ */ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { lazy val evaluated = GenerateProjection.expressionEvaluator(expression) val plan = try { - GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) } catch { case e: Throwable => fail( @@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 264a0eff37d3..6255578d7fa5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -30,7 +30,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, @@ -61,7 +61,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze - val actual = Optimize(plan).expressions.head + val actual = Optimize.execute(plan).expressions.head compareConditions(actual, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index e2ae0d25db1a..ad52283cfc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -42,9 +42,9 @@ class CombiningLimitsSuite extends PlanTest { testRelation .select('a) .limit(10) - .limit(5) + .limit(5) // wf: optimize the case for not adjacent limit - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(7) .limit(5) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 9fdf3efa02bb..14b28e840261 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification) :: Nil @@ -47,7 +47,7 @@ class ConstantFoldingSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -74,7 +74,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2)) )(Literal(9) / Literal(3) as Symbol("9/3")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -99,7 +99,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * 'a + Literal(4) as Symbol("c3"), 'a * (Literal(3) + Literal(4)) as Symbol("c4")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -127,7 +127,7 @@ class ConstantFoldingSuite extends PlanTest { (Literal(1) === Literal(1) || 'b > 1) && (Literal(1) === Literal(2) || 'b < 10))) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -144,7 +144,7 @@ class ConstantFoldingSuite extends PlanTest { Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -163,7 +163,7 @@ class ConstantFoldingSuite extends PlanTest { Rand + Literal(1) as Symbol("c1"), Sum('a) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -176,42 +176,41 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: expressions have null literals") { - val originalQuery = - testRelation - .select( - IsNull(Literal(null)) as 'c1, - IsNotNull(Literal(null)) as 'c2, + val originalQuery = testRelation.select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, - GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, - GetField( - Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), - "a") as 'c5, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem( + Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, + UnresolvedGetField( + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, - UnaryMinus(Literal(null, IntegerType)) as 'c6, - Cast(Literal(null), IntegerType) as 'c7, - Not(Literal(null, BooleanType)) as 'c8, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal(null, IntegerType), 1) as 'c9, - Add(1, Literal(null, IntegerType)) as 'c10, + Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal(null, StringType), "abc") as 'c13, - Like("abc", Literal(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal(null, StringType), "abc") as 'c19, - Contains("abc", Literal(null, StringType)) as 'c20 - ) + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 + ) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -219,31 +218,31 @@ class ConstantFoldingSuite extends PlanTest { Literal(true) as 'c1, Literal(false) as 'c2, - Literal(null, IntegerType) as 'c3, - Literal(null, IntegerType) as 'c4, - Literal(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as 'c3, + Literal.create(null, IntegerType) as 'c4, + Literal.create(null, IntegerType) as 'c5, - Literal(null, IntegerType) as 'c6, - Literal(null, IntegerType) as 'c7, - Literal(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as 'c6, + Literal.create(null, IntegerType) as 'c7, + Literal.create(null, BooleanType) as 'c8, - Literal(null, IntegerType) as 'c9, - Literal(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as 'c9, + Literal.create(null, IntegerType) as 'c10, - Literal(null, BooleanType) as 'c11, - Literal(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as 'c11, + Literal.create(null, BooleanType) as 'c12, - Literal(null, BooleanType) as 'c13, - Literal(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as 'c13, + Literal.create(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15, + Literal.create(null, StringType) as 'c15, - Literal(null, StringType) as 'c16, - Literal(null, StringType) as 'c17, - Literal(null, StringType) as 'c18, + Literal.create(null, StringType) as 'c16, + Literal.create(null, StringType) as 'c17, + Literal.create(null, StringType) as 'c18, - Literal(null, BooleanType) as 'c19, - Literal(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as 'c19, + Literal.create(null, BooleanType) as 'c20 ).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala new file mode 100644 index 000000000000..6841bd9890c9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ConvertToLocalRelationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil + } + + test("Project on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + Row(1, 2) :: + Row(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + Row(1, 3) :: + Row(4, 6) :: Nil) + + val projectOnLocal = testRelation.select( + UnresolvedAttribute("a").as("a1"), + (UnresolvedAttribute("b") + 1).as("b1")) + + val optimized = Optimize.execute(projectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index ae99a3f9ba28..a4a3a66b8b22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -29,8 +29,8 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, NoRelation) - val optimizedPlan = DefaultOptimizer(plan) + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ebb123c1f909..aa9708b164ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -18,23 +18,27 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions.{Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughJoin) :: Nil + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -46,7 +50,7 @@ class FilterPushdownSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -55,6 +59,38 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("column pruning for group") { + val originalQuery = + testRelation + .groupBy('a)('a, Count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a) + .select('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val originalQuery = + testRelation + .groupBy('a)('a as 'c, Count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c) + .select('c).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -62,7 +98,7 @@ class FilterPushdownSuite extends PlanTest { .select('a) .where('a === 1) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1) @@ -79,7 +115,7 @@ class FilterPushdownSuite extends PlanTest { .where('e === 1) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a + 'b === 1) @@ -95,7 +131,7 @@ class FilterPushdownSuite extends PlanTest { .where('a === 1) .where('a === 2) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1 && 'a === 2) @@ -116,7 +152,7 @@ class FilterPushdownSuite extends PlanTest { .where("y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -134,7 +170,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation val correctAnswer = @@ -152,7 +188,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -170,7 +206,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val correctAnswer = left.join(y, LeftOuter).where("y.b".attr === 2).analyze @@ -187,7 +223,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter).where("x.b".attr === 1).analyze @@ -204,7 +240,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('d) val correctAnswer = left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze @@ -221,7 +257,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze @@ -238,7 +274,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -256,7 +292,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze @@ -273,7 +309,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -291,7 +327,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -310,7 +346,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -329,7 +365,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 3).subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -346,9 +382,9 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, condition = Some("x.b".attr === "y.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(analysis.EliminateAnalysisOperators(originalQuery.analyze), optimized) + comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) } test("joins: conjunctive predicates") { @@ -360,14 +396,14 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.where('a === 1).subquery('y) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #2") { @@ -379,14 +415,14 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #3") { @@ -396,10 +432,11 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { z.join(x.join(y)) - .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && + ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val lleft = testRelation.where('a >= 3).subquery('z) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) @@ -409,6 +446,64 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } + + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + + test("generate: predicate referenced no generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode('c_arr), true, false, Some("arr")) + .where(('b >= 5) && ('a > 6)) + } + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = { + testRelationWithArrayType + .where(('b >= 5) && ('a > 6)) + .generate(Explode('c_arr), true, false, Some("arr")).analyze + } + + comparePlans(optimized, correctAnswer) + } + + test("generate: part of conjuncts referenced generated column") { + val generator = Explode('c_arr) + val originalQuery = { + testRelationWithArrayType + .generate(generator, true, false, Some("arr")) + .where(('b >= 5) && ('c > 6)) + } + val optimized = Optimize.execute(originalQuery.analyze) + val referenceResult = { + testRelationWithArrayType + .where('b >= 5) + .generate(generator, true, false, Some("arr")) + .where('c > 6).analyze + } + + // Since newly generated columns get different ids every time being analyzed + // e.g. comparePlans(originalQuery.analyze, originalQuery.analyze) fails. + // So we check operators manually here. + // Filter("c" > 6) + assertResult(classOf[Filter])(optimized.getClass) + assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size) + assertResult("c"){ + optimized.asInstanceOf[Filter].condition.references.toSeq(0).name + } + + // the rest part + comparePlans(optimized.children(0), referenceResult.children(0)) + } + + test("generate: all conjuncts referenced generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode('c_arr), true, false, Some("arr")) + .where(('c > 6) || ('b > 5)).analyze + } + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b10577c8001e..b3df487c84dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -41,7 +41,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "abc%") || ('a like "abc\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(StartsWith('a, "abc") || ('a like "abc\\%")) .analyze @@ -54,7 +54,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where('a like "%xyz") - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(EndsWith('a, "xyz")) .analyze @@ -67,7 +67,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "%mn%") || ('a like "%mn\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(Contains('a, "mn") || ('a like "%mn\\%")) .analyze @@ -80,7 +80,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "") || ('a like "abc")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(('a === "") || ('a === "abc")) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index da912ab38217..3eb399e68e70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -34,7 +34,7 @@ class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification, @@ -49,10 +49,10 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2)) + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) .analyze comparePlans(optimized, correctAnswer) @@ -64,7 +64,7 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index 22992fb6f50d..6b1e53cd42b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -41,7 +41,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Upper('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -55,7 +55,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Lower('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -69,7 +69,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Upper('a)) as 'l) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze @@ -82,7 +82,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Lower('a)) as 'l) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index dfef87bd9133..a3ad200800b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -29,7 +28,7 @@ class UnionPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Union Pushdown", Once, UnionPushdown) :: Nil } @@ -41,7 +40,7 @@ class UnionPushdownSuite extends PlanTest { test("union: filter to each side") { val query = testUnion.where('a === 1) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze @@ -52,7 +51,7 @@ class UnionPushdownSuite extends PlanTest { test("union: project to each side") { val query = testUnion.select('b) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.select('b), testRelation2.select('e)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c4a1f899d8a1..e7cafcc96de8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.plans import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ /** @@ -33,11 +33,11 @@ class PlanTest extends FunSuite { * we must normalize them to check if two different queries are identical. */ protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min plan transformAllExpressions { case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) + case a: Alias => + Alias(a.child, a.name)(exprId = ExprId(0)) } } @@ -45,11 +45,17 @@ class PlanTest extends FunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) + if (normalized1 != normalized2) { fail( s""" |== FAIL: Plans do not match === |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) + """.stripMargin) + } + } + + /** Fails the test if the two expressions do not match */ + protected def compareExpressions(e1: Expression, e2: Expression): Unit = { + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 11e6831b2476..1273921f6394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -32,7 +32,7 @@ class SameResultSuite extends FunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) - def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 4b2d45584045..2a641c63f87b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -34,7 +34,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("once", Once, DecrementLiterals) :: Nil } - assert(ApplyOnce(Literal(10)) === Literal(9)) + assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { @@ -42,7 +42,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(10)) === Literal(0)) + assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { @@ -50,6 +50,6 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(100)) === Literal(90)) + assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index cdb843f95970..6b393327cc97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children = optKey.toSeq - def nullable = true - def dataType = NullType + def children: Seq[Expression] = optKey.toSeq + def nullable: Boolean = true + def dataType: NullType = NullType override lazy val resolved = true type EvaluatedType = Any - def eval(input: Row) = null.asInstanceOf[Any] + def eval(input: Row): Any = null.asInstanceOf[Any] } class TreeNodeSuite extends FunSuite { @@ -90,7 +90,7 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } @@ -104,4 +104,30 @@ class TreeNodeSuite extends FunSuite { assert(actual === Dummy(None)) } + test("preserves origin") { + CurrentOrigin.setPosition(1,1) + val add = Add(Literal(1), Literal(1)) + CurrentOrigin.reset() + + val transformed = add transform { + case Literal(1, _) => Literal(2) + } + + assert(transformed.origin.line.isDefined) + assert(transformed.origin.startPosition.isDefined) + } + + test("foreach up") { + val actual = new ArrayBuffer[String]() + val expected = Seq("1", "2", "3", "4", "-", "*", "+") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression foreachUp { + case b: BinaryExpression => actual.append(b.symbol); + case l: Literal => actual.append(l.toString); + } + + assert(expected === actual) + } + + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala old mode 100755 new mode 100644 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala new file mode 100644 index 000000000000..169125264a80 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -0,0 +1,118 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +class DataTypeParserSuite extends FunSuite { + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(DataTypeParser(dataTypeString) === expectedDataType) + } + } + + def unsupported(dataTypeString: String): Unit = { + test(s"$dataTypeString is not supported") { + intercept[DataTypeException](DataTypeParser(dataTypeString)) + } + } + + checkDataType("int", IntegerType) + checkDataType("integer", IntegerType) + checkDataType("BooLean", BooleanType) + checkDataType("tinYint", ByteType) + checkDataType("smallINT", ShortType) + checkDataType("INT", IntegerType) + checkDataType("INTEGER", IntegerType) + checkDataType("bigint", LongType) + checkDataType("float", FloatType) + checkDataType("dOUBle", DoubleType) + checkDataType("decimal(10, 5)", DecimalType(10, 5)) + checkDataType("decimal", DecimalType.Unlimited) + checkDataType("DATE", DateType) + checkDataType("timestamp", TimestampType) + checkDataType("string", StringType) + checkDataType("varchAr(20)", StringType) + checkDataType("BINARY", BinaryType) + + checkDataType("array", ArrayType(DoubleType, true)) + checkDataType("Array>", ArrayType(MapType(IntegerType, ByteType, true), true)) + checkDataType( + "array>", + ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true) + ) + checkDataType("MAP", MapType(IntegerType, StringType, true)) + checkDataType("MAp>", MapType(IntegerType, ArrayType(DoubleType), true)) + checkDataType( + "MAP>", + MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true) + ) + + checkDataType( + "struct", + StructType( + StructField("intType", IntegerType, true) :: + StructField("ts", TimestampType, true) :: Nil) + ) + // It is fine to use the data type string as the column name. + checkDataType( + "Struct", + StructType( + StructField("int", IntegerType, true) :: + StructField("timestamp", TimestampType, true) :: Nil) + ) + checkDataType( + """ + |struct< + | struct:struct, + | MAP:Map, + | arrAy:Array> + """.stripMargin, + StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: + StructField("MAP", MapType(TimestampType, StringType), true) :: + StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + ) + // A column name can be a reserved word in our DDL parser and SqlParser. + checkDataType( + "Struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + // Use backticks to quote column names having special characters. + checkDataType( + "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", + StructType( + StructField("x+y", IntegerType, true) :: + StructField("!@#$%^&*()", StringType, true) :: + StructField("1_2.345<>:\"", StringType, true) :: Nil) + ) + // Empty struct. + checkDataType("strUCt<>", StructType(Nil)) + + unsupported("it is not a data type") + unsupported("struct") + unsupported("struct") + unsupported("struct<`x``y` int>") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c147be9f6b1a..d797510f3668 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite { } } + test("extract field index from a StructType") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + assert(struct.fieldIndex("a") === 0) + assert(struct.fieldIndex("b") === 1) + + intercept[IllegalArgumentException] { + struct.fieldIndex("non_existent") + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) @@ -106,8 +119,8 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) - checkDefaultSize(DateType, 8) - checkDefaultSize(TimestampType, 8) + checkDefaultSize(DateType, 4) + checkDefaultSize(TimestampType,12) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) @@ -115,4 +128,87 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(MapType(IntegerType, StringType, true), 410000) checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) checkDefaultSize(structType, 812) + + def checkEqualsIgnoreCompatibleNullability( + from: DataType, + to: DataType, + expected: Boolean): Unit = { + val testName = + s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})" + test(testName) { + assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected) + } + } + + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = true), + to = ArrayType(DoubleType, containsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(DoubleType, containsNull = false), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(DoubleType, containsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = true), + to = ArrayType(DoubleType, containsNull = false), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = ArrayType(DoubleType, containsNull = false), + to = ArrayType(StringType, containsNull = false), + expected = false) + + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = true), + to = MapType(StringType, DoubleType, valueContainsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = false), + to = MapType(StringType, DoubleType, valueContainsNull = false), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = false), + to = MapType(StringType, DoubleType, valueContainsNull = true), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, DoubleType, valueContainsNull = true), + to = MapType(StringType, DoubleType, valueContainsNull = false), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + expected = true) + + + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = true) :: Nil), + to = StructType(StructField("a", StringType, nullable = true) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = false) :: Nil), + to = StructType(StructField("a", StringType, nullable = false) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = false) :: Nil), + to = StructType(StructField("a", StringType, nullable = true) :: Nil), + expected = true) + checkEqualsIgnoreCompatibleNullability( + from = StructType(StructField("a", StringType, nullable = true) :: Nil), + to = StructType(StructField("a", StringType, nullable = false) :: Nil), + expected = false) + checkEqualsIgnoreCompatibleNullability( + from = StructType( + StructField("a", StringType, nullable = false) :: + StructField("b", StringType, nullable = true) :: Nil), + to = StructType( + StructField("a", StringType, nullable = false) :: + StructField("b", StringType, nullable = false) :: Nil), + expected = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 000000000000..a22aa6f244c4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +// scalastyle:off +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + + test("prefix") { + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + + test("suffix") { + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + + test("slice") { + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) + } +} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3e9ef07df9db..e3a6b1fe7243 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -66,6 +66,11 @@ jackson-databind 2.3.0 + + org.jodd + jodd-core + ${jodd.version} + junit junit @@ -76,9 +81,35 @@ scalacheck_${scala.binary.version} test + + com.h2database + h2 + 1.4.183 + test + + + mysql + mysql-connector-java + 5.1.34 + test + + + org.postgresql + postgresql + 9.3-1102-jdbc41 + test + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + ../../python + + pyspark/sql/*.py + + + diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java new file mode 100644 index 000000000000..a40be526d0d1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql; + +/** + * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. + */ +public enum SaveMode { + /** + * Append mode means that when saving a DataFrame to a data source, if data/table already exists, + * contents of the DataFrame are expected to be appended to existing data. + */ + Append, + /** + * Overwrite mode means that when saving a DataFrame to a data source, + * if data/table already exists, existing data is expected to be overwritten by the contents of + * the DataFrame. + */ + Overwrite, + /** + * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, + * an exception is expected to be thrown. + */ + ErrorIfExists, + /** + * Ignore mode means that when saving a DataFrame to a data source, if data already exists, + * the save operation is expected to not save the contents of the DataFrame and to not + * change the existing data. + */ + Ignore +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index e715d9434a2a..18584c2dcf79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.concurrent.locks.ReentrantReadWriteLock +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -32,9 +33,10 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR * results when subsequent queries are executed. Data is cached using byte buffers stored in an * InMemoryRelation. This relation is automatically substituted query plans that return the * `sameResult` as the originally cached query. + * + * Internal to Spark SQL. */ -private[sql] trait CacheManager { - self: SQLContext => +private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -43,13 +45,13 @@ private[sql] trait CacheManager { private val cacheLock = new ReentrantReadWriteLock /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(table(tableName), Some(tableName)) + def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { @@ -69,18 +71,24 @@ private[sql] trait CacheManager { } } + /** Clears all cached tables. */ private[sql] def clearCache(): Unit = writeLock { cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } + /** Checks if the cache is empty. */ + private[sql] def isEmpty: Boolean = readLock { + cachedData.isEmpty + } + /** * Caches the data produced by the logical representation of the given schema rdd. Unlike * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing * the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: SchemaRDD, + query: DataFrame, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -91,26 +99,26 @@ private[sql] trait CacheManager { CachedData( planToCache, InMemoryRelation( - conf.useCompression, - conf.columnBatchSize, + sqlContext.conf.useCompression, + sqlContext.conf.columnBatchSize, storageLevel, query.queryExecution.executedPlan, tableName)) } } - /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[DataFrame]] from the cache */ + private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData(dataIndex).cachedRepresentation.uncache(blocking) cachedData.remove(dataIndex) } - /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ + /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ private[sql] def tryUncacheQuery( - query: SchemaRDD, + query: DataFrame, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -122,8 +130,8 @@ private[sql] trait CacheManager { found } - /** Optionally returns cached data for the given SchemaRDD */ - private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[DataFrame]] */ + private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 000000000000..edb229c059e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,755 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField} +import org.apache.spark.sql.types._ + + +private[sql] object Column { + + def apply(colName: String): Column = new Column(colName) + + def apply(expr: Expression): Column = new Column(expr) + + def unapply(col: Column): Option[Expression] = Some(col.expr) +} + + +/** + * :: Experimental :: + * A column in a [[DataFrame]]. + * + * @groupname java_expr_ops Java-specific expression operators. + * @groupname expr_ops Expression operators. + * @groupname df_ops DataFrame functions. + * @groupname Ungrouped Support functions for DataFrames. + */ +@Experimental +class Column(protected[sql] val expr: Expression) extends Logging { + + def this(name: String) = this(name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + /** Creates a column based on the given expression. */ + implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr) + + override def toString: String = expr.prettyString + + override def equals(that: Any): Boolean = that match { + case that: Column => that.expr.equals(this.expr) + case _ => false + } + + override def hashCode: Int = this.expr.hashCode + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Scala: select the amount column and negates all values. + * df.select( -df("amount") ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.select( negate(col("amount") ); + * }}} + * + * @group expr_ops + */ + def unary_- : Column = UnaryMinus(expr) + + /** + * Inversion of boolean expression, i.e. NOT. + * {{ + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.filter( not(df.col("isActive")) ); + * }} + * + * @group expr_ops + */ + def unary_! : Column = Not(expr) + + /** + * Equality test. + * {{{ + * // Scala: + * df.filter( df("colA") === df("colB") ) + * + * // Java + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").equalTo(col("colB")) ); + * }}} + * + * @group expr_ops + */ + def === (other: Any): Column = { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + + "Perhaps you need to use aliases.") + } + EqualTo(expr, right) + } + + /** + * Equality test. + * {{{ + * // Scala: + * df.filter( df("colA") === df("colB") ) + * + * // Java + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").equalTo(col("colB")) ); + * }}} + * + * @group expr_ops + */ + def equalTo(other: Any): Column = this === other + + /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + * + * @group expr_ops + */ + def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + + /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + * + * @group java_expr_ops + */ + def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + + /** + * Greater than. + * {{{ + * // Scala: The following selects people older than 21. + * people.select( people("age") > 21 ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * people.select( people("age").gt(21) ); + * }}} + * + * @group expr_ops + */ + def > (other: Any): Column = GreaterThan(expr, lit(other).expr) + + /** + * Greater than. + * {{{ + * // Scala: The following selects people older than 21. + * people.select( people("age") > lit(21) ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * people.select( people("age").gt(21) ); + * }}} + * + * @group java_expr_ops + */ + def gt(other: Any): Column = this > other + + /** + * Less than. + * {{{ + * // Scala: The following selects people younger than 21. + * people.select( people("age") < 21 ) + * + * // Java: + * people.select( people("age").lt(21) ); + * }}} + * + * @group expr_ops + */ + def < (other: Any): Column = LessThan(expr, lit(other).expr) + + /** + * Less than. + * {{{ + * // Scala: The following selects people younger than 21. + * people.select( people("age") < 21 ) + * + * // Java: + * people.select( people("age").lt(21) ); + * }}} + * + * @group java_expr_ops + */ + def lt(other: Any): Column = this < other + + /** + * Less than or equal to. + * {{{ + * // Scala: The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * + * // Java: + * people.select( people("age").leq(21) ); + * }}} + * + * @group expr_ops + */ + def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr) + + /** + * Less than or equal to. + * {{{ + * // Scala: The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * + * // Java: + * people.select( people("age").leq(21) ); + * }}} + * + * @group java_expr_ops + */ + def leq(other: Any): Column = this <= other + + /** + * Greater than or equal to an expression. + * {{{ + * // Scala: The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * + * // Java: + * people.select( people("age").geq(21) ) + * }}} + * + * @group expr_ops + */ + def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr) + + /** + * Greater than or equal to an expression. + * {{{ + * // Scala: The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * + * // Java: + * people.select( people("age").geq(21) ) + * }}} + * + * @group java_expr_ops + */ + def geq(other: Any): Column = this >= other + + /** + * Equality test that is safe for null values. + * + * @group expr_ops + */ + def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr) + + /** + * Equality test that is safe for null values. + * + * @group java_expr_ops + */ + def eqNullSafe(other: Any): Column = this <=> other + + /** + * True if the current expression is null. + * + * @group expr_ops + */ + def isNull: Column = IsNull(expr) + + /** + * True if the current expression is NOT null. + * + * @group expr_ops + */ + def isNotNull: Column = IsNotNull(expr) + + /** + * Boolean OR. + * {{{ + * // Scala: The following selects people that are in school or employed. + * people.filter( people("inSchool") || people("isEmployed") ) + * + * // Java: + * people.filter( people("inSchool").or(people("isEmployed")) ); + * }}} + * + * @group expr_ops + */ + def || (other: Any): Column = Or(expr, lit(other).expr) + + /** + * Boolean OR. + * {{{ + * // Scala: The following selects people that are in school or employed. + * people.filter( people("inSchool") || people("isEmployed") ) + * + * // Java: + * people.filter( people("inSchool").or(people("isEmployed")) ); + * }}} + * + * @group java_expr_ops + */ + def or(other: Column): Column = this || other + + /** + * Boolean AND. + * {{{ + * // Scala: The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * + * // Java: + * people.select( people("inSchool").and(people("isEmployed")) ); + * }}} + * + * @group expr_ops + */ + def && (other: Any): Column = And(expr, lit(other).expr) + + /** + * Boolean AND. + * {{{ + * // Scala: The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * + * // Java: + * people.select( people("inSchool").and(people("isEmployed")) ); + * }}} + * + * @group java_expr_ops + */ + def and(other: Column): Column = this && other + + /** + * Sum of this expression and another expression. + * {{{ + * // Scala: The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * + * // Java: + * people.select( people("height").plus(people("weight")) ); + * }}} + * + * @group expr_ops + */ + def + (other: Any): Column = Add(expr, lit(other).expr) + + /** + * Sum of this expression and another expression. + * {{{ + * // Scala: The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * + * // Java: + * people.select( people("height").plus(people("weight")) ); + * }}} + * + * @group java_expr_ops + */ + def plus(other: Any): Column = this + other + + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // Scala: The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * + * // Java: + * people.select( people("height").minus(people("weight")) ); + * }}} + * + * @group expr_ops + */ + def - (other: Any): Column = Subtract(expr, lit(other).expr) + + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // Scala: The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * + * // Java: + * people.select( people("height").minus(people("weight")) ); + * }}} + * + * @group java_expr_ops + */ + def minus(other: Any): Column = this - other + + /** + * Multiplication of this expression and another expression. + * {{{ + * // Scala: The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * + * // Java: + * people.select( people("height").multiply(people("weight")) ); + * }}} + * + * @group expr_ops + */ + def * (other: Any): Column = Multiply(expr, lit(other).expr) + + /** + * Multiplication of this expression and another expression. + * {{{ + * // Scala: The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * + * // Java: + * people.select( people("height").multiply(people("weight")) ); + * }}} + * + * @group java_expr_ops + */ + def multiply(other: Any): Column = this * other + + /** + * Division this expression by another expression. + * {{{ + * // Scala: The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * + * // Java: + * people.select( people("height").divide(people("weight")) ); + * }}} + * + * @group expr_ops + */ + def / (other: Any): Column = Divide(expr, lit(other).expr) + + /** + * Division this expression by another expression. + * {{{ + * // Scala: The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * + * // Java: + * people.select( people("height").divide(people("weight")) ); + * }}} + * + * @group java_expr_ops + */ + def divide(other: Any): Column = this / other + + /** + * Modulo (a.k.a. remainder) expression. + * + * @group expr_ops + */ + def % (other: Any): Column = Remainder(expr, lit(other).expr) + + /** + * Modulo (a.k.a. remainder) expression. + * + * @group java_expr_ops + */ + def mod(other: Any): Column = this % other + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + * + * @group expr_ops + */ + @scala.annotation.varargs + def in(list: Column*): Column = In(expr, list.map(_.expr)) + + /** + * SQL like expression. + * + * @group expr_ops + */ + def like(literal: String): Column = Like(expr, lit(literal).expr) + + /** + * SQL RLIKE expression (LIKE with Regex). + * + * @group expr_ops + */ + def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + + /** + * An expression that gets an item at position `ordinal` out of an array, + * or gets a value by key `key` in a [[MapType]]. + * + * @group expr_ops + */ + def getItem(key: Any): Column = GetItem(expr, Literal(key)) + + /** + * An expression that gets a field by name in a [[StructType]]. + * + * @group expr_ops + */ + def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName) + + /** + * An expression that returns a substring. + * @param startPos expression for the starting position. + * @param len expression for the length of the substring. + * + * @group expr_ops + */ + def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr) + + /** + * An expression that returns a substring. + * @param startPos starting position. + * @param len length of the substring. + * + * @group expr_ops + */ + def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr) + + /** + * Contains the other element. + * + * @group expr_ops + */ + def contains(other: Any): Column = Contains(expr, lit(other).expr) + + /** + * String starts with. + * + * @group expr_ops + */ + def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr) + + /** + * String starts with another string literal. + * + * @group expr_ops + */ + def startsWith(literal: String): Column = this.startsWith(lit(literal)) + + /** + * String ends with. + * + * @group expr_ops + */ + def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr) + + /** + * String ends with another string literal. + * + * @group expr_ops + */ + def endsWith(literal: String): Column = this.endsWith(lit(literal)) + + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as("colB")) + * }}} + * + * @group expr_ops + */ + def as(alias: String): Column = Alias(expr, alias)() + + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as('colB)) + * }}} + * + * @group expr_ops + */ + def as(alias: Symbol): Column = Alias(expr, alias.name)() + + /** + * Gives the column an alias with metadata. + * {{{ + * val metadata: Metadata = ... + * df.select($"colA".as("colB", metadata)) + * }}} + * + * @group expr_ops + */ + def as(alias: String, metadata: Metadata): Column = { + Alias(expr, alias)(explicitMetadata = Some(metadata)) + } + + /** + * Casts the column to a different data type. + * {{{ + * // Casts colA to IntegerType. + * import org.apache.spark.sql.types.IntegerType + * df.select(df("colA").cast(IntegerType)) + * + * // equivalent to + * df.select(df("colA").cast("int")) + * }}} + * + * @group expr_ops + */ + def cast(to: DataType): Column = expr match { + // Lift alias out of cast so we can support col.as("name").cast(IntegerType) + case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)() + case _ => Cast(expr, to) + } + + /** + * Casts the column to a different data type, using the canonical string representation + * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, + * `float`, `double`, `decimal`, `date`, `timestamp`. + * {{{ + * // Casts colA to integer. + * df.select(df("colA").cast("int")) + * }}} + * + * @group expr_ops + */ + def cast(to: String): Column = cast(DataTypeParser(to)) + + /** + * Returns an ordering used in sorting. + * {{{ + * // Scala: sort a DataFrame by age column in descending order. + * df.sort(df("age").desc) + * + * // Java + * df.sort(df.col("age").desc()); + * }}} + * + * @group expr_ops + */ + def desc: Column = SortOrder(expr, Descending) + + /** + * Returns an ordering used in sorting. + * {{{ + * // Scala: sort a DataFrame by age column in ascending order. + * df.sort(df("age").asc) + * + * // Java + * df.sort(df.col("age").asc()); + * }}} + * + * @group expr_ops + */ + def asc: Column = SortOrder(expr, Ascending) + + /** + * Prints the expression to the console for debugging purpose. + * + * @group df_ops + */ + def explain(extended: Boolean): Unit = { + if (extended) { + println(expr) + } else { + println(expr.prettyString) + } + } +} + + +/** + * :: Experimental :: + * A convenient class used for constructing schema. + */ +@Experimental +class ColumnName(name: String) extends Column(name) { + + /** Creates a new AttributeReference of type boolean */ + def boolean: StructField = StructField(name, BooleanType) + + /** Creates a new AttributeReference of type byte */ + def byte: StructField = StructField(name, ByteType) + + /** Creates a new AttributeReference of type short */ + def short: StructField = StructField(name, ShortType) + + /** Creates a new AttributeReference of type int */ + def int: StructField = StructField(name, IntegerType) + + /** Creates a new AttributeReference of type long */ + def long: StructField = StructField(name, LongType) + + /** Creates a new AttributeReference of type float */ + def float: StructField = StructField(name, FloatType) + + /** Creates a new AttributeReference of type double */ + def double: StructField = StructField(name, DoubleType) + + /** Creates a new AttributeReference of type string */ + def string: StructField = StructField(name, StringType) + + /** Creates a new AttributeReference of type date */ + def date: StructField = StructField(name, DateType) + + /** Creates a new AttributeReference of type decimal */ + def decimal: StructField = StructField(name, DecimalType.Unlimited) + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int): StructField = + StructField(name, DecimalType(precision, scale)) + + /** Creates a new AttributeReference of type timestamp */ + def timestamp: StructField = StructField(name, TimestampType) + + /** Creates a new AttributeReference of type binary */ + def binary: StructField = StructField(name, BinaryType) + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): StructField = + map(MapType(keyType, valueType)) + + def map(mapType: MapType): StructField = StructField(name, mapType) + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): StructField = struct(StructType(fields)) + + def struct(structType: StructType): StructField = StructField(name, structType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala new file mode 100644 index 000000000000..ca6ae482eb2a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -0,0 +1,1391 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.io.CharArrayWriter +import java.sql.DriverManager + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} +import org.apache.spark.sql.jdbc.JDBCWriteDetails +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} +import org.apache.spark.util.Utils + + +private[sql] object DataFrame { + def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } +} + + +/** + * :: Experimental :: + * A distributed collection of data organized into named columns. + * + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways + * to create a [[DataFrame]]: + * {{{ + * // Create a DataFrame from Parquet files + * val people = sqlContext.parquetFile("...") + * + * // Create a DataFrame from data sources + * val df = sqlContext.load("...", "json") + * }}} + * + * Once created, it can be manipulated using the various domain-specific-language (DSL) functions + * defined in: [[DataFrame]] (this class), [[Column]], and [[functions]]. + * + * To select a column from the data frame, use `apply` method in Scala and `col` in Java. + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.col("age") // in Java + * }}} + * + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{{ + * // The following creates a new column that increases everybody's age by 10. + * people("age") + 10 // in Scala + * people.col("age").plus(10); // in Java + * }}} + * + * A more concrete example in Scala: + * {{{ + * // To create DataFrame using SQLContext + * val people = sqlContext.parquetFile("...") + * val department = sqlContext.parquetFile("...") + * + * people.filter("age > 30") + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), "gender") + * .agg(avg(people("salary")), max(people("age"))) + * }}} + * + * and in Java: + * {{{ + * // To create DataFrame using SQLContext + * DataFrame people = sqlContext.parquetFile("..."); + * DataFrame department = sqlContext.parquetFile("..."); + * + * people.filter("age".gt(30)) + * .join(department, people.col("deptId").equalTo(department("id"))) + * .groupBy(department.col("name"), "gender") + * .agg(avg(people.col("salary")), max(people.col("age"))); + * }}} + * + * @groupname basic Basic DataFrame functions + * @groupname dfops Language Integrated Queries + * @groupname rdd RDD Operations + * @groupname output Output Operations + * @groupname action Actions + */ +// TODO: Improve documentation. +@Experimental +class DataFrame private[sql]( + @transient val sqlContext: SQLContext, + @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) + extends RDDApi[Row] with Serializable { + + /** + * A constructor that automatically analyzes the logical plan. + * + * This reports error eagerly as the [[DataFrame]] is constructed, unless + * [[SQLConf.dataFrameEagerAnalysis]] is turned off. + */ + def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { + this(sqlContext, { + val qe = sqlContext.executePlan(logicalPlan) + if (sqlContext.conf.dataFrameEagerAnalysis) { + qe.assertAnalyzed() // This should force analysis and throw errors if there are any + } + qe + }) + } + + @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | + _: InsertIntoTable | + _: CreateTableAsSelect[_] | + _: CreateTableUsingAsSelect | + _: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + queryExecution.analyzed + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan) + } + + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } + + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get + } + } + + /** + * Internal API for Python + * @param numRows Number of rows to show + */ + private[sql] def showString(numRows: Int): String = { + val data = take(numRows) + val numCols = schema.fieldNames.length + + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + row.toSeq.map { cell => + val str = if (cell == null) "null" else cell.toString + if (str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + } + + // Compute the width of each column + val colWidths = Array.fill(numCols)(0) + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Pad the cells + rows.map { row => + row.zipWithIndex.map { case (cell, i) => + String.format(s"%-${colWidths(i)}s", cell) + }.mkString(" ") + }.mkString("\n") + } + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } + + /** Left here for backward compatibility. */ + @deprecated("1.3.0", "use toDF") + def toSchemaRDD: DataFrame = this + + /** + * Returns the object itself. + * @group basic + */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = this + + /** + * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion + * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: + * {{{ + * val rdd: RDD[(Int, String)] = ... + * rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" + * }}} + * @group basic + */ + @scala.annotation.varargs + def toDF(colNames: String*): DataFrame = { + require(schema.size == colNames.size, + "The number of columns doesn't match.\n" + + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) + + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) + } + select(newCols :_*) + } + + /** + * Returns the schema of this [[DataFrame]]. + * @group basic + */ + def schema: StructType = queryExecution.analyzed.schema + + /** + * Returns all column names and their data types as an array. + * @group basic + */ + def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** + * Returns all column names as an array. + * @group basic + */ + def columns: Array[String] = schema.fields.map(_.name) + + /** + * Prints the schema to the console in a nice tree format. + * @group basic + */ + def printSchema(): Unit = println(schema.treeString) + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + */ + def explain(extended: Boolean): Unit = { + ExplainCommand( + queryExecution.logical, + extended = extended).queryExecution.executedPlan.executeCollect().map { + r => println(r.getString(0)) + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @group basic + */ + def explain(): Unit = explain(extended = false) + + /** + * Returns true if the `collect` and `take` methods can be run locally + * (without any Spark executors). + * @group basic + */ + def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @group action + */ + def show(numRows: Int): Unit = println(showString(numRows)) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * @group action + */ + def show(): Unit = show(20) + + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. + * df.na.drop() + * }}} + * + * @group dfops + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + + /** + * Cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + * @group dfops + */ + def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given column. + * + * Different from other join functions, the join column will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the column "user_id" + * df1.join(df2, "user_id") + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumn Name of the column to join on. This column must exist on both sides. + * @group dfops + */ + def join(right: DataFrame, usingColumn: String): DataFrame = { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sqlContext.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + + // Project only one of the join column. + val joinedCol = joined.right.resolve(usingColumn) + Project( + joined.output.filterNot(_ == joinedCol), + Join( + joined.left, + joined.right, + joinType = Inner, + Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) + ) + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + * @group dfops + */ + def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, Some(joinExprs.expr)) + } + + /** + * Join with another [[DataFrame]], using the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * // Scala: + * import org.apache.spark.sql.functions._ + * df1.join(df2, $"df1Key" === $"df2Key", "outer") + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + * @group dfops + */ + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + /** + * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. + * {{{ + * // The following 3 are equivalent + * df.sort("sortcol") + * df.sort($"sortcol") + * df.sort($"sortcol".asc) + * }}} + * @group dfops + */ + @scala.annotation.varargs + def sort(sortCol: String, sortCols: String*): DataFrame = { + sort((sortCol +: sortCols).map(apply) :_*) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. For example: + * {{{ + * df.sort($"col1", $"col2".desc) + * }}} + * @group dfops + */ + @scala.annotation.varargs + def sort(sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + * @group dfops + */ + @scala.annotation.varargs + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + * @group dfops + */ + @scala.annotation.varargs + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) + + /** + * Selects column based on the column name and return it as a [[Column]]. + * @group dfops + */ + def apply(colName: String): Column = col(colName) + + /** + * Selects column based on the column name and return it as a [[Column]]. + * @group dfops + */ + def col(colName: String): Column = colName match { + case "*" => + Column(ResolvedStar(schema.fieldNames.map(resolve))) + case _ => + val expr = resolve(colName) + Column(expr) + } + + /** + * Returns a new [[DataFrame]] with an alias set. + * @group dfops + */ + def as(alias: String): DataFrame = Subquery(alias, logicalPlan) + + /** + * (Scala-specific) Returns a new [[DataFrame]] with an alias set. + * @group dfops + */ + def as(alias: Symbol): DataFrame = as(alias.name) + + /** + * Selects a set of expressions. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + * @group dfops + */ + @scala.annotation.varargs + def select(cols: Column*): DataFrame = { + val namedExpressions = cols.map { + case Column(expr: NamedExpression) => expr + case Column(expr: Expression) => Alias(expr, expr.prettyString)() + } + Project(namedExpressions.toSeq, logicalPlan) + } + + /** + * Selects a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * df.select("colA", "colB") + * df.select($"colA", $"colB") + * }}} + * @group dfops + */ + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + + /** + * Selects a set of SQL expressions. This is a variant of `select` that accepts + * SQL expressions. + * + * {{{ + * df.selectExpr("colA", "colB as newName", "abs(colC)") + * }}} + * @group dfops + */ + @scala.annotation.varargs + def selectExpr(exprs: String*): DataFrame = { + select(exprs.map { expr => + Column(new SqlParser().parseExpression(expr)) + }: _*) + } + + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + * @group dfops + */ + def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan) + + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDf.filter("age > 15") + * }}} + * @group dfops + */ + def filter(conditionExpr: String): DataFrame = { + filter(Column(new SqlParser().parseExpression(conditionExpr))) + } + + /** + * Filters rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + * @group dfops + */ + def where(condition: Column): DataFrame = filter(condition) + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + */ + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + */ + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + new GroupedData(this, colNames.map(colName => resolve(colName))) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + * @group dfops + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupBy().agg(aggExpr, aggExprs :_*) + } + + /** + * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + * @group dfops + */ + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * (Java-specific) Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + * @group dfops + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(max($"age"), avg($"salary")) + * df.groupBy().agg(max($"age"), avg($"salary")) + * }} + * @group dfops + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + + /** + * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + * @group dfops + */ + def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + + /** + * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. + * This is equivalent to `UNION ALL` in SQL. + * @group dfops + */ + def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. + * This is equivalent to `INTERSECT` in SQL. + * @group dfops + */ + def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. + * This is equivalent to `EXCEPT` in SQL. + * @group dfops + */ + def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + * @group dfops + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @group dfops + */ + def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + /** + * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * the input row are implicitly joined with each row that is output by the function. + * + * The following example uses this function to count the number of books which contain + * a given word: + * + * {{{ + * case class Book(title: String, words: String) + * val df: RDD[Book] + * + * case class Word(word: String) + * val allWords = df.explode('words) { + * case Row(words: String) => words.split(" ").map(Word(_)) + * } + * + * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) + * }}} + * @group dfops + */ + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + + val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val names = schema.toAttributes.map(_.name) + + val rowFunction = + f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) + + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + } + + /** + * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero + * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the function. + * + * {{{ + * df.explode("words", "word")(words: String => words.split(" ")) + * }}} + * @group dfops + */ + def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) + : DataFrame = { + val dataType = ScalaReflection.schemaFor[B].dataType + val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } + val names = attributes.map(_.name) + + def rowFunction(row: Row): TraversableOnce[Row] = { + f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) + } + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) + + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns a new [[DataFrame]] by adding a column. + * @group dfops + */ + def withColumn(colName: String, col: Column): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName) else Column(name) + } + select(colNames :_*) + } else { + select(Column("*"), col.as(colName)) + } + } + + /** + * Returns a new [[DataFrame]] with a column renamed. + * @group dfops + */ + def withColumnRenamed(existingName: String, newName: String): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val colNames = schema.map { field => + val name = field.name + if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + } + select(colNames :_*) + } + + /** + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * df.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) + } + + /** + * Returns the first `n` rows. + * @group action + */ + def head(n: Int): Array[Row] = limit(n).collect() + + /** + * Returns the first row. + * @group action + */ + def head(): Row = head(1).head + + /** + * Returns the first row. Alias for head(). + * @group action + */ + override def first(): Row = head() + + /** + * Returns a new RDD by applying a function to all rows of this DataFrame. + * @group rdd + */ + override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) + + /** + * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], + * and then flattening the results. + * @group rdd + */ + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + + /** + * Returns a new RDD by applying a function to each partition of this DataFrame. + * @group rdd + */ + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Applies a function `f` to all rows. + * @group rdd + */ + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + + /** + * Applies a function f to each partition of this [[DataFrame]]. + * @group rdd + */ + override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + + /** + * Returns the first `n` rows in the [[DataFrame]]. + * @group action + */ + override def take(n: Int): Array[Row] = head(n) + + /** + * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * @group action + */ + override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() + + /** + * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * @group action + */ + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + + /** + * Returns the number of rows in the [[DataFrame]]. + * @group action + */ + override def count(): Long = groupBy().count().collect().head.getLong(0) + + /** + * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * @group rdd + */ + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.createDataFrame( + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) + } + + /** + * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @group rdd + */ + override def coalesce(numPartitions: Int): DataFrame = { + sqlContext.createDataFrame( + queryExecution.toRdd.coalesce(numPartitions), + schema, + needsConversion = false) + } + + /** + * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. + * @group dfops + */ + override def distinct: DataFrame = Distinct(logicalPlan) + + /** + * @group basic + */ + override def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * @group basic + */ + override def cache(): this.type = persist() + + /** + * @group basic + */ + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * @group basic + */ + override def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * @group basic + */ + override def unpersist(): this.type = unpersist(blocking = false) + + ///////////////////////////////////////////////////////////////////////////// + // I/O + ///////////////////////////////////////////////////////////////////////////// + + /** + * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * memoized. Once called, it won't change even if you change any query planning related Spark SQL + * configurations (e.g. `spark.sql.shuffle.partitions`). + * @group rdd + */ + lazy val rdd: RDD[Row] = { + // use a local variable to make sure the map closure doesn't capture the whole DataFrame + val schema = this.schema + queryExecution.executedPlan.execute().mapPartitions { rows => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]) + } + } + + /** + * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd + */ + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + /** + * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd + */ + def javaRDD: JavaRDD[Row] = toJavaRDD + + /** + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. + * + * @group basic + */ + def registerTempTable(tableName: String): Unit = { + sqlContext.registerDataFrameAsTable(this, tableName) + } + + /** + * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. + * Files that are written out using this method can be read back in as a [[DataFrame]] + * using the `parquetFile` function in [[SQLContext]]. + * @group output + */ + def saveAsParquetFile(path: String): Unit = { + if (sqlContext.conf.parquetUseDataSourceApi) { + save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) + } else { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * This will fail if the table already exists. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String): Unit = { + saveAsTable(tableName, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame, using the default data source + * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String, mode: SaveMode): Unit = { + if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { + // If table already exists and the save mode is Append, + // we will just call insertInto to append the contents of this DataFrame. + insertInto(tableName, overwrite = false) + } else { + val dataSourceName = sqlContext.conf.defaultDataSourceName + saveAsTable(tableName, dataSourceName, mode) + } + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source and a set of options, + * using [[SaveMode.ErrorIfExists]] as the save mode. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String, source: String): Unit = { + saveAsTable(tableName, source, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { + saveAsTable(tableName, source, mode, Map.empty[String, String]) + } + + /** + * :: Experimental :: + * Creates a table at the given path from the the contents of this DataFrame + * based on a given data source, [[SaveMode]] specified by mode, and a set of options. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable( + tableName: String, + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + saveAsTable(tableName, source, mode, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Creates a table from the the contents of this DataFrame based on a given data source, + * [[SaveMode]] specified by mode, and a set of options. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + * @group output + */ + @Experimental + def saveAsTable( + tableName: String, + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + mode, + options, + logicalPlan) + + sqlContext.executePlan(cmd).toRdd + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path, + * using the default data source configured by spark.sql.sources.default and + * [[SaveMode.ErrorIfExists]] as the save mode. + * @group output + */ + @Experimental + def save(path: String): Unit = { + save(path, SaveMode.ErrorIfExists) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, + * using the default data source configured by spark.sql.sources.default. + * @group output + */ + @Experimental + def save(path: String, mode: SaveMode): Unit = { + val dataSourceName = sqlContext.conf.defaultDataSourceName + save(path, dataSourceName, mode) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source, + * using [[SaveMode.ErrorIfExists]] as the save mode. + * @group output + */ + @Experimental + def save(path: String, source: String): Unit = { + save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame to the given path based on the given data source and + * [[SaveMode]] specified by mode. + * @group output + */ + @Experimental + def save(path: String, source: String, mode: SaveMode): Unit = { + save(source, mode, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options. + * @group output + */ + @Experimental + def save( + source: String, + mode: SaveMode, + options: java.util.Map[String, String]): Unit = { + save(source, mode, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Saves the contents of this DataFrame based on the given data source, + * [[SaveMode]] specified by mode, and a set of options + * @group output + */ + @Experimental + def save( + source: String, + mode: SaveMode, + options: Map[String, String]): Unit = { + ResolvedDataSource(sqlContext, source, mode, options, this) + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + * @group output + */ + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table. + * Throws an exception if the table already exists. + * @group output + */ + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + /** + * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + * @group rdd + */ + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JsonRDD.rowToJSON(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } + } + } + + //////////////////////////////////////////////////////////////////////////// + // JDBC Write Support + //////////////////////////////////////////////////////////////////////////// + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + */ + def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { + val conn = DriverManager.getConnection(url) + try { + if (allowExisting) { + val sql = s"DROP TABLE IF EXISTS $table" + conn.prepareStatement(sql).executeUpdate() + } + val schema = JDBCWriteDetails.schemaString(this, url) + val sql = s"CREATE TABLE $table ($schema)" + conn.prepareStatement(sql).executeUpdate() + } finally { + conn.close() + } + JDBCWriteDetails.saveTable(this, url, table) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + */ + def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + if (overwrite) { + val conn = DriverManager.getConnection(url) + try { + val sql = s"TRUNCATE TABLE $table" + conn.prepareStatement(sql).executeUpdate() + } finally { + conn.close() + } + } + JDBCWriteDetails.saveTable(this, url, table) + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala new file mode 100644 index 000000000000..a3187fe3230f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -0,0 +1,30 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * A container for a [[DataFrame]], used for implicit conversions. + */ +private[sql] case class DataFrameHolder(df: DataFrame) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = df + + def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala new file mode 100644 index 000000000000..481ed4924857 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -0,0 +1,375 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.{lang => jl} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functionality for working with missing data in [[DataFrame]]s. + */ +@Experimental +final class DataFrameNaFunctions private[sql](df: DataFrame) { + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values. + */ + def drop(): DataFrame = drop("any", df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values. + * + * If `how` is "any", then drop rows containing any null values. + * If `how` is "all", then drop rows only if every column is null for that row. + */ + def drop(how: String): DataFrame = drop(how, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * in the specified columns. + */ + def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) + + /** + * Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * in the specified columns. + * + * If `how` is "any", then drop rows containing any null values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null for that row. + */ + def drop(how: String, cols: Seq[String]): DataFrame = { + how.toLowerCase match { + case "any" => drop(cols.size, cols) + case "all" => drop(1, cols) + case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") + } + } + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + */ + def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) + + /** + * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null + * values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than + * `minNonNulls` non-null values in the specified columns. + */ + def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { + // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) + df.filter(Column(predicate)) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + */ + def fill(value: Double): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + */ + def fill(value: String): DataFrame = fill(value, df.columns) + + /** + * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + */ + def fill(value: Double, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[Double](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values in specified string columns. + * If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in + * specified string columns. If a specified column is not a string column, it is ignored. + */ + def fill(value: String, cols: Seq[String]): DataFrame = { + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + // Only fill if the column is part of the cols list. + if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { + fillCol[String](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + /** + * Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * import com.google.common.collect.ImmutableMap; + * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); + * }}} + */ + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. + * + * The key of the map is the column name, and the value of the map is the replacement value. + * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`. + * + * For example, the following replaces null values in column "A" with string "unknown", and + * null values in column "B" with numeric value 1.0. + * {{{ + * df.na.fill(Map( + * "A" -> "unknown", + * "B" -> 1.0 + * )) + * }}} + */ + def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { + replace[T](col, replacement.toMap : Map[T, T]) + } + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { + replace(cols.toSeq, replacement.toMap) + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", Map(1.0 -> 2.0)) + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", Map("UNKNOWN" -> "unnamed") + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", Map("UNKNOWN" -> "unnamed") + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: Map[T, T]): DataFrame = { + if (col == "*") { + replace0(df.columns, replacement) + } else { + replace0(Seq(col), replacement) + } + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement) + + private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { + if (replacement.isEmpty || cols.isEmpty) { + return df + } + + // replacementMap is either Map[String, String] or Map[Double, Double] + val replacementMap: Map[_, _] = replacement.head._2 match { + case v: String => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } + + // targetColumnType is either DoubleType or StringType + val targetColumnType = replacement.head._1 match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: String => StringType + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) + if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { + replaceCol(f, replacementMap) + } else if (f.dataType == targetColumnType && shouldReplace) { + replaceCol(f, replacementMap) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + + private def fill0(values: Seq[(String, Any)]): DataFrame = { + // Error handling + values.foreach { case (colName, replaceValue) => + // Check column name exists + df.resolve(colName) + + // Check data type + replaceValue match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String => + // This is good + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).") + } + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => + v match { + case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Double => fillCol[Double](f, v) + case v: jl.Long => fillCol[Double](f, v.toDouble) + case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: String => fillCol[String](f, v) + } + }.getOrElse(df.col(f.name)) + } + df.select(projections : _*) + } + + /** + * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. + */ + private def fillCol[T](col: StructField, replacement: T): Column = { + coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + } + + /** + * Returns a [[Column]] expression that replaces value matching key in `replacementMap` with + * value in `replacementMap`, using [[CaseWhen]]. + * + * TODO: This can be optimized to use broadcast join when replacementMap is large. + */ + private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { + val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) => + df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr :: + lit(target).cast(col.dataType).expr :: Nil + }.toSeq + new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name) + } + + private def convertToDouble(v: Any): Double = v match { + case v: Float => v.toDouble + case v: Double => v + case v: Long => v.toDouble + case v: Int => v.toDouble + case v => throw new IllegalArgumentException( + s"Unsupported value type ${v.getClass.getName} ($v).") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index f0e6a8f33218..d5d7e35a6b35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -20,8 +20,13 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental /** + * :: Experimental :: * Holder for experimental methods for the bravest. We make NO guarantee about the stability * regarding binary compatibility and source compatibility of methods here. + * + * {{{ + * sqlContext.experimental.extraStrategies += ... + * }}} */ @Experimental class ExperimentalMethods protected[sql](sqlContext: SQLContext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala new file mode 100644 index 000000000000..53ad67372e02 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.types.NumericType + + +/** + * :: Experimental :: + * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. + */ +@Experimental +class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { + + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + val namedGroupingExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + DataFrame( + df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + } + + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + : Seq[NamedExpression] = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric. + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be applied on a numeric column.") + } + namedExpr + } + } + columnExprs.map { c => + val a = f(c) + Alias(a, a.prettyString)() + } + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + expr.toLowerCase match { + case "avg" | "average" | "mean" => Average + case "max" => Max + case "min" => Min + case "sum" => Sum + case "count" | "size" => + // Turn count(*) into count(1) + (inputExpr: Expression) => inputExpr match { + case s: Star => Count(Literal(1)) + case _ => Count(inputExpr) + } + } + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + agg((aggExpr +: aggExprs).toMap) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + */ + def agg(exprs: Map[String, String]): DataFrame = { + exprs.map { case (colName, expr) => + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.prettyString)() + }.toSeq + } + + /** + * (Java-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); + * }}} + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this + * class, the resulting [[DataFrame]] won't automatically include the grouping columns. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: + * import org.apache.spark.sql.functions._ + * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense"))); + * }}} + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = { + val aggExprs = (expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + + /** + * Count the number of rows for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + */ + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } + + /** + * Compute the max value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. + */ + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Max) + } + + /** + * Compute the mean value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. + */ + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } + + /** + * Compute the min value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. + */ + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Min) + } + + /** + * Compute the sum for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + */ + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Sum) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala new file mode 100644 index 000000000000..db484c5f5007 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.beans.Introspector +import java.lang.{Iterable => JIterable} +import java.util.{Iterator => JIterator, Map => JMap} + +import com.google.common.reflect.TypeToken + +import org.apache.spark.sql.types._ + +import scala.language.existentials + +/** + * Type-inference utilities for POJOs and Java collections. + */ +private [sql] object JavaTypeInference { + + private val iterableType = TypeToken.of(classOf[JIterable[_]]) + private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType + private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType + private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType + private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + + /** + * Infers the corresponding SQL data type of a Java type. + * @param typeToken Java type + * @return (SQL data type, nullable) + */ + private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + typeToken.getRawType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case _ if typeToken.isArray => + val (dataType, nullable) = inferDataType(typeToken.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ if iterableType.isAssignableFrom(typeToken) => + val (dataType, nullable) = inferDataType(elementType(typeToken)) + (ArrayType(dataType, nullable), true) + + case _ if mapType.isAssignableFrom(typeToken) => + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) + val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyDataType, _) = inferDataType(keyType) + val (valueDataType, nullable) = inferDataType(valueType) + (MapType(keyDataType, valueDataType, nullable), true) + + case _ => + val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) + } + } + + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] + val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSupertype.resolveType(iteratorReturnType) + val itemType = iteratorType.resolveType(nextReturnType) + itemType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala new file mode 100644 index 000000000000..63dbab19947c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -0,0 +1,67 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +private[sql] trait RDDApi[T] { + + def cache(): this.type + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def foreach(f: T => Unit): Unit + + def foreachPartition(f: Iterator[T] => Unit): Unit + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame + + def coalesce(numPartitions: Int): DataFrame + + def distinct: DataFrame +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 243dc997078f..4fc5de7e824f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -33,9 +33,13 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" + val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -43,10 +47,25 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" + // This is used to set the default data source + val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" + // This is used to control the when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters). We will split the JSON string of a schema + // to its length exceeds the threshold. + val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + + // Whether to perform eager analysis when constructing a dataframe. + // Set to false when debugging requires the ability to look at invalid query plans. + val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -101,9 +120,24 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetFilterPushDown = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses Parquet implementation based on data source API */ + private[spark] def parquetUseDataSourceApi = + getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + + /** When true uses verifyPartitionPath to prune the path which is not exists. */ + private[spark] def verifyPartitionPath = + getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster @@ -115,6 +149,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 @@ -140,6 +176,12 @@ private[sql] class SQLConf extends Serializable { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, we always treat INT96Values in Parquet files as timestamp. + */ + private[spark] def isParquetINT96AsTimestamp: Boolean = + getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ @@ -155,6 +197,17 @@ private[sql] class SQLConf extends Serializable { private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt + private[spark] def defaultDataSourceName: String = + getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + + // Do not use a value larger than 4000 as the default value of this property. + // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. + private[spark] def schemaStringLengthThreshold: Int = + getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt + + private[spark] def dataFrameEagerAnalysis: Boolean = + getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a22968cc780..d8aa67801310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -20,133 +20,356 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties +import scala.collection.JavaConversions._ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import com.google.common.reflect.TypeToken + +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} +import org.apache.spark.sql.execution.{Filter, _} +import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ -import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation, DDLParser, DataSourceStrategy} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.{Partition, SparkContext} /** - * :: AlphaComponent :: - * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] - * objects and the execution of SQL queries. + * The entry point for working with structured data (rows and columns) in Spark. Allows the + * creation of [[DataFrame]] objects as well as the execution of SQL queries. * - * @groupname userf Spark SQL Functions + * @groupname basic Basic Operations + * @groupname ddl_ops Persistent Catalog DDL + * @groupname cachemgmt Cached Table Management + * @groupname genericdata Generic Data Sources + * @groupname specificdata Specific Data Sources + * @groupname config Configuration + * @groupname dataframes Custom DataFrame Creation * @groupname Ungrouped Support functions for language integrated queries. */ -@AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging - with CacheManager - with ExpressionConversions with Serializable { self => def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) - // Note that this is a lazy val so we can override the default value in subclasses. - protected[sql] lazy val conf: SQLConf = new SQLConf + /** + * @return Spark SQL configuration + */ + protected[sql] def conf = tlSession.get().conf // 出于并发考虑 - /** Set Spark SQL configuration properties. */ + /** + * Set Spark SQL configuration properties. + * + * @group config + */ def setConf(props: Properties): Unit = conf.setConf(props) - /** Set the given Spark SQL configuration property. */ + /** + * Set the given Spark SQL configuration property. + * + * @group config + */ def setConf(key: String, value: String): Unit = conf.setConf(key, value) - /** Return the value of Spark SQL configuration property for the given key. */ + /** + * Return the value of Spark SQL configuration property for the given key. + * + * @group config + */ def getConf(key: String): String = conf.getConf(key) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set * yet, return `defaultValue`. + * + * @group config */ def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). * This creates a new copy of the config properties in the form of a Map. + * + * @group config */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + // TODO how to handle the temp table per user session? @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) { + override val extendedResolutionRules = + ExtractPythonUdfs :: + sources.PreInsertCastAndRename :: + Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient - protected[sql] val ddlParser = new DDLParser + protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql, false).getOrElse(sqlParser(sql)) + ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + + protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) + + @transient + protected[sql] val tlSession = new ThreadLocal[SQLSession]() { + override def initialValue: SQLSession = defaultSession + } + + @transient + protected[sql] val defaultSession = createSession() sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) case _ => } + @transient + protected[sql] val cacheManager = new CacheManager(this) + + /** + * :: Experimental :: + * A collection of methods that are considered experimental, but can be used to hook into + * the query planner for advanced functionality. + * + * @group basic + */ + @Experimental + @transient + val experimental: ExperimentalMethods = new ExperimentalMethods(this) + + /** + * :: Experimental :: + * Returns a [[DataFrame]] with no rows or columns. + * + * @group basic + */ + @Experimental + @transient + lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) // 比如读parquet文件为空时,则返回该DF + + /** + * A collection of methods for registering user-defined functions (UDF). + * + * The following example registers a Scala closure as UDF: + * {{{ + * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sqlContext.udf().register("myUDF", + * new UDF2() { + * @Override + * public String call(Integer arg1, String arg2) { + * return arg2 + arg1; + * } + * }, DataTypes.StringType); + * }}} + * + * Or, to use Java 8 lambda syntax: + * {{{ + * sqlContext.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1), + * DataTypes.StringType); + * }}} + * + * @group basic + * TODO move to SQLSession? + */ + @transient + val udf: UDFRegistration = new UDFRegistration(this) + + /** + * Returns true if the table is currently cached in-memory. + * @group cachemgmt + */ + def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + + /** + * Caches the specified table in-memory. + * @group cachemgmt + */ + def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + + /** + * Removes the specified table from the in-memory cache. + * @group cachemgmt + */ + def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + + /** + * Removes all cached tables from the in-memory cache. + */ + def clearCache(): Unit = cacheManager.clearCache() + + // scalastyle:off + // Disable style checker so "implicits" object can start with lowercase i + /** + * :: Experimental :: + * (Scala-specific) Implicit methods available in Scala for converting + * common Scala objects into [[DataFrame]]s. + * + * {{{ + * val sqlContext = new SQLContext(sc) + * import sqlContext.implicits._ + * }}} + * + * @group basic + */ + @Experimental + object implicits extends Serializable { + // scalastyle:on + + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** Creates a DataFrame from an RDD of case classes or tuples. */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(self.createDataFrame(rdd)) // todo: 为什么搞一个DataFrameHolder? + } + + /** Creates a DataFrame from a local Seq of Product. */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(self.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** Creates a single column DataFrame from an RDD[Int]. */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** Creates a single column DataFrame from an RDD[Long]. */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** Creates a single column DataFrame from an RDD[String]. */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setString(0, v) + row: Row + } + } + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + } + /** - * Creates a SchemaRDD from an RDD of case classes. + * :: Experimental :: + * Creates a DataFrame from an RDD of case classes. * - * @group userf + * @group dataframes */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { + @Experimental + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) - val attributeSeq = ScalaReflection.attributesFor[A] - val schema = StructType.fromAttributes(attributeSeq) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } /** - * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + * :: Experimental :: + * Creates a DataFrame from a local Seq of Product. + * + * @group dataframes */ - def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - new SchemaRDD(this, LogicalRelation(baseRelation)) + @Experimental + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + SparkPlan.currentContext.set(self) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) + } + + /** + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. + * + * @group dataframes + */ + def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { + DataFrame(this, LogicalRelation(baseRelation)) // todo:为什么是 LogicalRelation,按道理这会导致本地表的产生, 看错了 是LogicalRelation 不是 LocalRelation  } /** * :: DeveloperApi :: - * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * val sqlContext = new org.apache.spark.sql.SQLContext(sc) * * val schema = @@ -157,24 +380,66 @@ class SQLContext(@transient val sparkContext: SparkContext) * val people = * sc.textFile("examples/src/main/resources/people.txt").map( * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val peopleSchemaRDD = sqlContext. applySchema(people, schema) - * peopleSchemaRDD.printSchema + * val dataFrame = sqlContext.createDataFrame(people, schema) + * dataFrame.printSchema * // root * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerTempTable("people") + * dataFrame.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * - * @group userf + * @group dataframes */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { - // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new SchemaRDD(this, logicalPlan) + val catalystRows = if (needsConversion) { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + rowRDD.map(converter(_).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + DataFrame(this, logicalPlan) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + */ + @DeveloperApi + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD.rdd, schema) + } + + /** + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s by applying + * a seq of names of columns to this RDD, the data type for each column will + * be inferred by the first row. + * + * @param rowRDD an JavaRDD of Row + * @param columns names for each column + * @return DataFrame + * @group dataframes + */ + def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = { + createDataFrame(rowRDD.rdd, columns.toSeq) // todo:这个是调用谁? } /** @@ -182,8 +447,9 @@ class SQLContext(@transient val sparkContext: SparkContext) * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. + * @group dataframes */ - def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { // 不保序,有点恐怖 val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -196,12 +462,12 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { row => new GenericRow( extractors.zip(attributeSeq).map { case (e, attr) => - DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) }.toArray[Any] ) : Row } } - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -209,97 +475,454 @@ class SQLContext(@transient val sparkContext: SparkContext) * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. + * @group dataframes */ - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { - applySchema(rdd.rdd, beanClass) + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd.rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[SchemaRDD]]. + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val dataFrame = sqlContext. applySchema(people, schema) + * dataFrame.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * dataFrame.registerTempTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Applies a schema to an RDD of Java Beans. * - * @group userf + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. */ - def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @group specificdata + */ + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else if (conf.parquetUseDataSourceApi) { + baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this)) + } else { + DataFrame(this, parquet.ParquetRelation( + paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } + } /** - * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * - * @group userf + * @group specificdata */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) /** * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * - * @group userf + * @group specificdata */ @Experimental - def jsonFile(path: String, schema: StructType): SchemaRDD = { - val json = sparkContext.textFile(path) - jsonRDD(json, schema) - } + def jsonFile(path: String, schema: StructType): DataFrame = + load("json", schema, Map("path" -> path)) /** * :: Experimental :: + * @group specificdata */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { - val json = sparkContext.textFile(path) - jsonRDD(json, samplingRatio) - } + def jsonFile(path: String, samplingRatio: Double): DataFrame = + load("json", Map("path" -> path, "samplingRatio" -> samplingRatio.toString)) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + */ + def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) + /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[SchemaRDD]]. + * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * - * @group userf + * @group specificdata */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0) /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * - * @group userf + * @group specificdata */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** * :: Experimental :: + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + jsonRDD(json.rdd, schema) + } + + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @Experimental + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - applySchema(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** - * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only - * during the lifetime of this instance of SQLContext. + * :: Experimental :: + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. * - * @group userf + * @group specificdata */ - def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) + @Experimental + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + jsonRDD(json.rdd, samplingRatio); + } + + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + */ + @Experimental + def load(path: String): DataFrame = { // 使用 默认的外部datasource 来导入文件,这里的问题是可能外部的datasource 不需要path参数,而是要其他参数 + val dataSourceName = conf.defaultDataSourceName + load(path, dataSourceName) + } + + /** + * :: Experimental :: + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + */ + @Experimental + def load(path: String, source: String): DataFrame = { + load(source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @Experimental + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + load(source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @Experimental + def load(source: String, options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, None, source, options) + DataFrame(this, LogicalRelation(resolved.relation)) + } + + /** + * :: Experimental :: + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + */ + @Experimental + def load( + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + load(source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * @group genericdata + */ + @Experimental + def load( + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val resolved = ResolvedDataSource(this, Some(schema), source, options) + DataFrame(this, LogicalRelation(resolved.relation)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + path: String, + source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } + + /** + * :: Experimental :: + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. + * + * @group specificdata + */ + @Experimental + def jdbc(url: String, table: String): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null)) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * + * @group specificdata + */ + @Experimental + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) + val parts = JDBCRelation.columnPartition(partitioning) + jdbc(url, table, parts) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. The theParts parameter gives a list expressions + * suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * @group specificdata + */ + @Experimental + def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => + JDBCPartition(part, i) : Partition + } + jdbc(url, table, parts) + } + + private def jdbc(url: String, table: String, parts: Array[Partition]): DataFrame = { + val relation = JDBCRelation(url, table, parts)(this) + baseRelationToDataFrame(relation) + } + + /** + * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist + * only during the lifetime of this instance of SQLContext. + */ + private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), df.logicalPlan) } /** @@ -308,73 +931,87 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @param tableName the name of the table to be unregistered. * - * @group userf + * @group basic */ def dropTempTable(tableName: String): Unit = { - tryUncacheQuery(table(tableName)) + cacheManager.tryUncacheQuery(table(tableName)) catalog.unregisterTable(Seq(tableName)) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. * - * @group userf + * @group basic */ - def sql(sqlText: String): SchemaRDD = { + def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new SchemaRDD(this, parseSql(sqlText)) + DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } - /** Returns the specified table as a SchemaRDD */ - def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) + /** + * Returns the specified table as a [[DataFrame]]. + * + * @group ddl_ops + */ + def table(tableName: String): DataFrame = + DataFrame(this, catalog.lookupRelation(Seq(tableName))) /** - * A collection of methods that are considered experimental, but can be used to hook into - * the query planner for advanced functionalities. + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + * + * @group ddl_ops */ - val experimental: ExperimentalMethods = new ExperimentalMethods(this) + def tables(): DataFrame = { + DataFrame(this, ShowTablesCommand(None)) + } /** - * A collection of methods for registering user-defined functions (UDF). + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). * - * The following example registers a Scala closure as UDF: - * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) - * }}} + * @group ddl_ops + */ + def tables(databaseName: String): DataFrame = { + DataFrame(this, ShowTablesCommand(Some(databaseName))) + } + + /** + * Returns the names of tables in the current database as an array. * - * The following example registers a UDF in Java: - * {{{ - * sqlContext.udf().register("myUDF", - * new UDF2() { - * @Override - * public String call(Integer arg1, String arg2) { - * return arg2 + arg1; - * } - * }, DataTypes.StringType); - * }}} + * @group ddl_ops + */ + def tableNames(): Array[String] = { + catalog.getTables(None).map { + case (tableName, _) => tableName + }.toArray + } + + /** + * Returns the names of tables in the given database as an array. * - * Or, to use Java 8 lambda syntax: - * {{{ - * sqlContext.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1), - * DataTypes.StringType); - * }}} + * @group ddl_ops */ - val udf: UDFRegistration = new UDFRegistration(this) + def tableNames(databaseName: String): Array[String] = { + catalog.getTables(Some(databaseName)).map { + case (tableName, _) => tableName + }.toArray + } protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext val sqlContext: SQLContext = self - def codegenEnabled = self.conf.codegenEnabled + def codegenEnabled: Boolean = self.conf.codegenEnabled - def numPartitions = self.conf.numShufflePartitions + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = experimental.extraStrategies ++ ( @@ -411,7 +1048,8 @@ class SQLContext(@transient val sparkContext: SparkContext) val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this @@ -437,7 +1075,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) // todo: 用来干啥的? 和上面的emptyDF的关系? /** * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. @@ -445,7 +1083,33 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil // 这里用了once,哈哈 + } + + protected[sql] def openSession(): SQLSession = { + detachSession() + val session = createSession() + tlSession.set(session) + + session + } + + protected[sql] def currentSession(): SQLSession = { + tlSession.get() + } + + protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } + + protected[sql] def detachSession(): Unit = { + tlSession.remove() + } + + protected[sql] class SQLSession { + // Note that this is a lazy val so we can override the default value in subclasses. + // scala里面lazy用法 很精妙 + protected[sql] lazy val conf: SQLConf = new SQLConf } /** @@ -454,27 +1118,30 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. */ @DeveloperApi - protected abstract class QueryExecution { - def logical: LogicalPlan + protected[sql] class QueryExecution(val logical: LogicalPlan) { // 提出去,不当作内部类,目前已有PR在重构 + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData = useCachedData(analyzed) - lazy val optimizedPlan = optimizer(withCachedData) + lazy val analyzed: LogicalPlan = analyzer.execute(logical) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() // 确认analyze 没有出错 + cacheManager.useCachedData(analyzed) // 如果该表已经被 cache了,使用cache的表 + } + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) // TODO: Don't just pick the first one... - lazy val sparkPlan = { + lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) planner(optimizedPlan).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[Row] = executedPlan.execute() protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: Throwable => e.toString } + try f.toString catch { case e: Throwable => e.toString } // 短小精悍的小函数 def simpleString: String = s"""== Physical Plan == @@ -485,6 +1152,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) // however, the `toRdd` will cause the real execution, which is not what we want. // We need to think about how to avoid the side effect. + // todo: 我们可以考虑怎么把RDD的信息也打印出来 s"""== Parsed Logical Plan == |${stringOrError(logical)} |== Analyzed Logical Plan == @@ -512,7 +1180,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schemaString: String): SchemaRDD = { + schemaString: String): DataFrame = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } @@ -522,14 +1190,16 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schema: StructType): SchemaRDD = { + schema: StructType): DataFrame = { def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true case ShortType => true + case LongType => true case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true @@ -549,45 +1219,19 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. - val beanInfo = Introspector.getBeanInfo(beanClass) - - // Note: The ordering of elements may differ from when the schema is inferred in Scala. - // This is because beanInfo.getPropertyDescriptors gives no guarantees about - // element ordering. - val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - fields.map { property => - val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - } - AttributeReference(property.getName, dataType, nullable)() + val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) + dataType.asInstanceOf[StructType].fields.map { f => + AttributeReference(f.name, f.dataType, f.nullable)() } } + } + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala deleted file mode 100644 index d1e21dffeb8c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ /dev/null @@ -1,511 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import net.razorvine.pickle.Pickler - -import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.annotation.{AlphaComponent, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{BooleanType, StructType} -import org.apache.spark.storage.StorageLevel - -/** - * :: AlphaComponent :: - * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, - * SchemaRDDs can be used in relational queries, as shown in the examples below. - * - * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD - * whose elements are scala case classes into a SchemaRDD. This conversion can also be done - * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. - * - * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] - * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. - * - * == SQL Queries == - * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once - * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. - * - * {{{ - * // One method for defining the schema of an RDD is to make a case class with the desired column - * // names and types. - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - * // Any RDD containing case classes can be registered as a table. The schema of the table is - * // automatically inferred using scala reflection. - * rdd.registerTempTable("records") - * - * val results: SchemaRDD = sql("SELECT * FROM records") - * }}} - * - * == Language Integrated Queries == - * - * {{{ - * - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) - * - * // Example of language integrated queries. - * rdd.where('key === 1).orderBy('value.asc).select('key).collect() - * }}} - * - * @groupname Query Language Integrated Queries - * @groupdesc Query Functions that create new queries from SchemaRDDs. The - * result of all query functions is also a SchemaRDD, allowing multiple operations to be - * chained using a builder pattern. - * @groupprio Query -2 - * @groupname schema SchemaRDD Functions - * @groupprio schema -1 - * @groupname Ungrouped Base RDD Functions - */ -@AlphaComponent -class SchemaRDD( - @transient val sqlContext: SQLContext, - @transient val baseLogicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - - def baseSchemaRDD = this - - // ========================================================================================= - // RDD functions: Copy the internal row representation so we present immutable data to users. - // ========================================================================================= - - override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) - - override def getPartitions: Array[Partition] = firstParent[Row].partitions - - override protected def getDependencies: Seq[Dependency[_]] = { - schema // Force reification of the schema so it is available on executors. - - List(new OneToOneDependency(queryExecution.toRdd)) - } - - /** - * Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - lazy val schema: StructType = queryExecution.analyzed.schema - - /** - * Returns a new RDD with each row transformed to a JSON string. - * - * @group schema - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - - // ======================================================================= - // Query DSL - // ======================================================================= - - /** - * Changes the output of this relation to the given expressions, similar to the `SELECT` clause - * in SQL. - * - * {{{ - * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) - * }}} - * - * @param exprs a set of logical expression that will be evaluated for each input row. - * - * @group Query - */ - def select(exprs: Expression*): SchemaRDD = { - val aliases = exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) - } - - /** - * Filters the output, only returning those rows where `condition` evaluates to true. - * - * {{{ - * schemaRDD.where('a === 'b) - * schemaRDD.where('a === 1) - * schemaRDD.where('a + 'b > 10) - * }}} - * - * @group Query - */ - def where(condition: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) - - /** - * Performs a relational join on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be joined with this one. - * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param on An optional condition for the join operation. This is equivalent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivalent to adding `where` clauses after the `join`. - * - * @group Query - */ - def join( - otherPlan: SchemaRDD, - joinType: JoinType = Inner, - on: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) - - /** - * Sorts the results by the given expressions. - * {{{ - * schemaRDD.orderBy('a) - * schemaRDD.orderBy('a, 'b) - * schemaRDD.orderBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) - - /** - * Sorts the results by the given expressions within partition. - * {{{ - * schemaRDD.sortBy('a) - * schemaRDD.sortBy('a, 'b) - * schemaRDD.sortBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) - - @deprecated("use limit with integer argument", "1.1.0") - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) - - /** - * Limits the results by the given integer. - * {{{ - * schemaRDD.limit(10) - * }}} - * @group Query - */ - def limit(limitNum: Int): SchemaRDD = - new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) - - /** - * Performs a grouping followed by an aggregation. - * - * {{{ - * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() - } - new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) - } - - /** - * Performs an aggregation over all Rows in this RDD. - * This is equivalent to a groupBy with no grouping expressions. - * - * {{{ - * schemaRDD.aggregate(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def aggregate(aggregateExprs: Expression*): SchemaRDD = { - groupBy()(aggregateExprs: _*) - } - - /** - * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes - * with the same name, for example, when performing self-joins. - * - * {{{ - * val x = schemaRDD.where('a === 1).as('x) - * val y = schemaRDD.where('a === 2).as('y) - * x.join(y).where("x.a".attr === "y.a".attr), - * }}} - * - * @group Query - */ - def as(alias: Symbol) = - new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) - - /** - * Combines the tuples of two RDDs with the same schema, keeping duplicates. - * - * @group Query - */ - def unionAll(otherPlan: SchemaRDD) = - new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational except on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. - * - * @group Query - */ - def except(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational intersect on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. - * - * @group Query - */ - def intersect(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) - - /** - * Filters tuples using a function over the value of the specified column. - * - * {{{ - * schemaRDD.where('a)((a: Int) => ...) - * }}} - * - * @group Query - */ - def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - - /** - * :: Experimental :: - * Returns a sampled version of the underlying dataset. - * - * @group Query - */ - @Experimental - override - def sample( - withReplacement: Boolean = true, - fraction: Double, - seed: Long) = - new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) - - /** - * :: Experimental :: - * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this - * implementation leverages the query optimizer to compute the count on the SchemaRDD, which - * supports features such as filter pushdown. - * - * @group Query - */ - @Experimental - override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) - - /** - * :: Experimental :: - * Applies the given Generator, or table generating function, to this relation. - * - * @param generator A table generating function. The API for such functions is likely to change - * in future releases - * @param join when set to true, each output row of the generator is joined with the input row - * that produced it. - * @param outer when set to true, at least one row will be produced for each input row, similar to - * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a - * given row, a single row will be output, with `NULL` values for each of the - * generated columns. - * @param alias an optional alias that can be used as qualifier for the attributes that are - * produced by this generate operation. - * - * @group Query - */ - @Experimental - def generate( - generator: Generator, - join: Boolean = false, - outer: Boolean = false, - alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) - - /** - * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit - * conversion from a standard RDD to a SchemaRDD. - * - * @group schema - */ - def toSchemaRDD = this - - /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. - */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same - * format as javaToPython. It is used by pyspark. - */ - private[sql] def collectToPython: JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(collect().map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same - * format as javaToPython and collectToPython. It is used by pyspark. - */ - private[sql] def takeSampleToPython( - withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value - * of base RDD functions that do not change schema. - * - * @param rdd RDD derived from this one and has same schema - * - * @group schema - */ - private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, - LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) - } - - // ======================================================================= - // Overridden RDD actions - // ======================================================================= - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*) - - override def take(num: Int): Array[Row] = limit(num).collect() - - // ======================================================================= - // Base RDD functions that do NOT change schema - // ======================================================================= - - // Transformations (return a new RDD) - - override def coalesce(numPartitions: Int, shuffle: Boolean = false) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.coalesce(numPartitions, shuffle)(ord)) - - override def distinct(): SchemaRDD = applySchema(super.distinct()) - - override def distinct(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.distinct(numPartitions)(ord)) - - def distinct(numPartitions: Int): SchemaRDD = - applySchema(super.distinct(numPartitions)(null)) - - override def filter(f: Row => Boolean): SchemaRDD = - applySchema(super.filter(f)) - - override def intersection(other: RDD[Row]): SchemaRDD = - applySchema(super.intersection(other)) - - override def intersection(other: RDD[Row], partitioner: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.intersection(other, partitioner)(ord)) - - override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.intersection(other, numPartitions)) - - override def repartition(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.repartition(numPartitions)(ord)) - - override def subtract(other: RDD[Row]): SchemaRDD = - applySchema(super.subtract(other)) - - override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.subtract(other, numPartitions)) - - override def subtract(other: RDD[Row], p: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.subtract(other, p)(ord)) - - /** Overridden cache function will always use the in-memory columnar caching. */ - override def cache(): this.type = { - sqlContext.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.tryUncacheQuery(this, blocking) - this - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala deleted file mode 100644 index 3cf9209465b7..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD - -/** - * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) - */ -private[sql] trait SchemaRDDLike { - @transient def sqlContext: SQLContext - @transient val baseLogicalPlan: LogicalPlan - - private[sql] def baseSchemaRDD: SchemaRDD - - /** - * :: DeveloperApi :: - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. This workflow is produced lazily because - * invoking the whole query optimization pipeline can be expensive. - * - * The query execution is considered a Developer API as phases may be added or removed in future - * releases. This execution is only exposed to provide an interface for inspecting the various - * phases for debugging purposes. Applications should not depend on particular phases existing - * or producing any specific output, even for exactly the same query. - * - * Additionally, the RDD exposed by this execution is not designed for consumption by end users. - * In particular, it does not contain any schema information, and it reuses Row objects - * internally. This object reuse improves performance, but can make programming against the RDD - * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly. - */ - @transient - @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } - - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.simpleString}""".stripMargin.trim - - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` - * function. - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. - * - * @group schema - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) - } - - @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") - def registerAsTable(tableName: String): Unit = registerTempTable(tableName) - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - - /** - * :: Experimental :: - * Appends the rows from this RDD to the specified table. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * :: Experimental :: - * Creates a table from the the contents of this SchemaRDD. This will fail if the table already - * exists. - * - * Note that this currently only works with SchemaRDDs that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * @group schema - */ - @Experimental - def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd - - /** Returns the schema as a string in the tree format. - * - * @group schema - */ - def schemaString: String = baseSchemaRDD.schema.treeString - - /** Prints out the schema. - * - * @group schema - */ - def printSchema(): Unit = println(schemaString) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index f1a4053b7911..2972f629438a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -23,7 +23,7 @@ import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.types.StringType @@ -33,11 +33,13 @@ import org.apache.spark.sql.types.StringType * * @param fallback A function that parses an input string to a logical plan */ +// 最顶层的 parser, 内部回调各个 dialect的 parser +// 目前已有的两个dialect parser是:1 SqlParser 2 ExtendedHiveQlParser。 这两个parser最好重命名为:SqlDialectParser 和 HiveQlDialectParser private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { - private val key: Parser[String] = "(?m)[^=]+".r + private val key: Parser[String] = "(?m)[^=]+".r // todo: 正则的含义? 需要学习下 private val value: Parser[String] = "(?m).*$".r @@ -57,12 +59,16 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val AS = Keyword("AS") protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others + override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -71,15 +77,22 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } private lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } + ( UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + | CLEAR ~ CACHE ^^^ ClearCacheCommand + ) private lazy val set: Parser[LogicalPlan] = SET ~> restInput ^^ { case input => SetCommandParser(input) } + private lazy val show: Parser[LogicalPlan] = + SHOW ~> TABLES ~ (IN ~> ident).? ^^ { + case _ ~ dbName => ShowTablesCommand(dbName) + } + private lazy val others: Parser[LogicalPlan] = wholeInput ^^ { case input => fallback(input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala similarity index 56% rename from sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala rename to sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 2e9d037f93c0..b97aaf73529a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -21,7 +21,7 @@ import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ @@ -32,9 +32,9 @@ import org.apache.spark.sql.types.DataType /** - * Functions for registering user-defined functions. + * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. */ -class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging { +class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry @@ -61,7 +61,7 @@ class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging val dataType = sqlContext.parseDataType(stringDataType) - def builder(e: Seq[Expression]) = + def builder(e: Seq[Expression]): PythonUDF = PythonUDF( name, command, @@ -78,20 +78,21 @@ class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging // scalastyle:off - /* registerFunction 0-22 were generated by this script + /* register 0-22 were generated by this script (0 to 22).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val argDocs = (1 to x).map(i => s" * @tparam A$i type of the UDF argument at position $i.").foldLeft("")(_ + "\n" + _) println(s""" /** * Register a Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF.$argDocs + * @tparam RT return type of UDF. */ - def register[$typeTags](name: String, func: Function$x[$types]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) }""") } @@ -116,462 +117,258 @@ class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging * Register a Scala closure of 0 arguments as user-defined function (UDF). * @tparam RT return type of UDF. */ - def register[RT: TypeTag](name: String, func: Function0[RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 1 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. */ - def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 2 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 3 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 4 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 5 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 6 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 7 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 8 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 9 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 10 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 11 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 12 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 13 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 14 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 15 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 16 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 17 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 18 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 19 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 20 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 21 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - * @tparam A21 type of the UDF argument at position 21. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** * Register a Scala closure of 22 arguments as user-defined function (UDF). * @tparam RT return type of UDF. - * @tparam A1 type of the UDF argument at position 1. - * @tparam A2 type of the UDF argument at position 2. - * @tparam A3 type of the UDF argument at position 3. - * @tparam A4 type of the UDF argument at position 4. - * @tparam A5 type of the UDF argument at position 5. - * @tparam A6 type of the UDF argument at position 6. - * @tparam A7 type of the UDF argument at position 7. - * @tparam A8 type of the UDF argument at position 8. - * @tparam A9 type of the UDF argument at position 9. - * @tparam A10 type of the UDF argument at position 10. - * @tparam A11 type of the UDF argument at position 11. - * @tparam A12 type of the UDF argument at position 12. - * @tparam A13 type of the UDF argument at position 13. - * @tparam A14 type of the UDF argument at position 14. - * @tparam A15 type of the UDF argument at position 15. - * @tparam A16 type of the UDF argument at position 16. - * @tparam A17 type of the UDF argument at position 17. - * @tparam A18 type of the UDF argument at position 18. - * @tparam A19 type of the UDF argument at position 19. - * @tparam A20 type of the UDF argument at position 20. - * @tparam A21 type of the UDF argument at position 21. - * @tparam A22 type of the UDF argument at position 22. - */ - def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { + val dataType = ScalaReflection.schemaFor[RT].dataType + def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + /** * Register a user-defined function with 1 arguments. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala new file mode 100644 index 000000000000..295db539adfc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -0,0 +1,67 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.execution.PythonUDF +import org.apache.spark.sql.types.DataType + +/** + * A user-defined function. To create one, use the `udf` functions in [[functions]]. + * As an example: + * {{{ + * // Defined a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => if (score > 0.5) true else false) + * + * // Projects a column that adds a prediction column based on the score column. + * df.select( predict(df("score")) ) + * }}} + */ +case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { + + def apply(exprs: Column*): Column = { + Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + } +} + +/** + * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. + * This is used by Python API. + */ +private[sql] case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType) { + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, + accumulator, dataType, exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala new file mode 100644 index 000000000000..cbbd005228d4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala @@ -0,0 +1,23 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * Contains API classes that are specific to a single language (i.e. Java). + */ +package object api diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 000000000000..ae77f72998a2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createStructType(fields : Seq[StructField]): StructType = { + StructType(fields) + } + + def getSQLDataType(dataType: String): DataType = { + dataType match { + case "byte" => org.apache.spark.sql.types.ByteType + case "integer" => org.apache.spark.sql.types.IntegerType + case "double" => org.apache.spark.sql.types.DoubleType + case "numeric" => org.apache.spark.sql.types.DoubleType + case "character" => org.apache.spark.sql.types.StringType + case "string" => org.apache.spark.sql.types.StringType + case "binary" => org.apache.spark.sql.types.BinaryType + case "raw" => org.apache.spark.sql.types.BinaryType + case "logical" => org.apache.spark.sql.types.BooleanType + case "boolean" => org.apache.spark.sql.types.BooleanType + case "timestamp" => org.apache.spark.sql.types.TimestampType + case "date" => org.apache.spark.sql.types.DateType + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + + def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { + val dtObj = getSQLDataType(dataType) + StructField(name, dtObj, nullable) + } + + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 91c4c105b14e..64449b2659b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor -import org.apache.spark.sql.types.{BinaryType, DataType, NativeType} +import org.apache.spark.sql.types._ /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is @@ -48,9 +48,9 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def initialize() {} - def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int): Unit = { + override def extractTo(row: MutableRow, ordinal: Int): Unit = { extractSingle(row, ordinal) } @@ -61,7 +61,7 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def underlyingBuffer = buffer } -private[sql] abstract class NativeColumnAccessor[T <: NativeType]( +private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) @@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) @@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: ByteBuffer) with NullableColumnAccessor private[sql] object ColumnAccessor { - def apply(buffer: ByteBuffer): ColumnAccessor = { + def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) - // The first 4 bytes in the buffer indicate the column type. - val columnTypeId = dup.getInt() - - columnTypeId match { - case INT.typeId => new IntColumnAccessor(dup) - case LONG.typeId => new LongColumnAccessor(dup) - case FLOAT.typeId => new FloatColumnAccessor(dup) - case DOUBLE.typeId => new DoubleColumnAccessor(dup) - case BOOLEAN.typeId => new BooleanColumnAccessor(dup) - case BYTE.typeId => new ByteColumnAccessor(dup) - case SHORT.typeId => new ShortColumnAccessor(dup) - case STRING.typeId => new StringColumnAccessor(dup) - case DATE.typeId => new DateColumnAccessor(dup) - case TIMESTAMP.typeId => new TimestampColumnAccessor(dup) - case BINARY.typeId => new BinaryColumnAccessor(dup) - case GENERIC.typeId => new GenericColumnAccessor(dup) + + // The first 4 bytes in the buffer indicate the column type. This field is not used now, + // because we always know the data type of the column ahead of time. + dup.getInt() + + dataType match { + case IntegerType => new IntColumnAccessor(dup) + case LongType => new LongColumnAccessor(dup) + case FloatType => new FloatColumnAccessor(dup) + case DoubleType => new DoubleColumnAccessor(dup) + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) + case StringType => new StringColumnAccessor(dup) + case BinaryType => new BinaryColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) + case DecimalType.Fixed(precision, scale) if precision < 19 => + new FixedDecimalColumnAccessor(dup, precision, scale) + case _ => new GenericColumnAccessor(dup) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 3a4977b836af..aa10af400c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -58,7 +58,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( override def initialize( initialSize: Int, columnName: String = "", - useCompression: Boolean = false) = { + useCompression: Boolean = false): Unit = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName @@ -73,7 +73,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( columnType.append(row, ordinal, buffer) } - override def build() = { + override def build(): ByteBuffer = { buffer.flip().asInstanceOf[ByteBuffer] } } @@ -84,10 +84,10 @@ private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: NativeType]( +private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) + extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] @@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleCol private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class FixedDecimalColumnBuilder( + precision: Int, + scale: Int) + extends NativeColumnBuilder( + new FixedDecimalColumnStats, + FIXED_DECIMAL(precision, scale)) + private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) @@ -139,25 +146,26 @@ private[sql] object ColumnBuilder { } def apply( - typeId: Int, + dataType: DataType, initialSize: Int = 0, columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { - - val builder = (typeId match { - case INT.typeId => new IntColumnBuilder - case LONG.typeId => new LongColumnBuilder - case FLOAT.typeId => new FloatColumnBuilder - case DOUBLE.typeId => new DoubleColumnBuilder - case BOOLEAN.typeId => new BooleanColumnBuilder - case BYTE.typeId => new ByteColumnBuilder - case SHORT.typeId => new ShortColumnBuilder - case STRING.typeId => new StringColumnBuilder - case BINARY.typeId => new BinaryColumnBuilder - case GENERIC.typeId => new GenericColumnBuilder - case DATE.typeId => new DateColumnBuilder - case TIMESTAMP.typeId => new TimestampColumnBuilder - }).asInstanceOf[ColumnBuilder] + val builder: ColumnBuilder = dataType match { + case IntegerType => new IntColumnBuilder + case LongType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder + case DoubleType => new DoubleColumnBuilder + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder + case StringType => new StringColumnBuilder + case BinaryType => new BinaryColumnBuilder + case DateType => new DateColumnBuilder + case TimestampType => new TimestampColumnBuilder + case DecimalType.Fixed(precision, scale) if precision < 19 => + new FixedDecimalColumnBuilder(precision, scale) + case _ => new GenericColumnBuilder + } builder.initialize(initialSize, columnName, useCompression) builder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 391b3dae5c8c..b0f983c18067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -76,7 +76,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) - def collectedStatistics = Row(null, null, nullCount, count, 0L) + override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,7 +93,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,7 +110,7 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -127,7 +127,7 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -144,7 +144,7 @@ private[sql] class LongColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -161,7 +161,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -178,7 +178,24 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) +} + +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize + } + } + + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { @@ -195,42 +212,27 @@ private[sql] class IntColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends ColumnStats { - protected var upper: Date = null - protected var lower: Date = null - - override def gatherStats(row: Row, ordinal: Int) { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Date] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += DATE.defaultSize - } - } - - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) -} +private[sql] class DateColumnStats extends IntColumnStats private[sql] class TimestampColumnStats extends ColumnStats { protected var upper: Timestamp = null @@ -246,7 +248,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { } } - def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -257,7 +259,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats extends ColumnStats { @@ -268,5 +270,5 @@ private[sql] class GenericColumnStats extends ColumnStats { } } - def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fcf2faa0914c..20be5ca9d004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -98,23 +98,23 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def clone(v: JvmType): JvmType = v - override def toString = getClass.getSimpleName.stripSuffix("$") + override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] abstract class NativeColumnType[T <: NativeType]( +private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, typeId: Int, defaultSize: Int) - extends ColumnType[T, T#JvmType](typeId, defaultSize) { + extends ColumnType[T, T#InternalType](typeId, defaultSize) { /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. */ - def scalaTag: TypeTag[dataType.JvmType] = dataType.tag + def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer): Unit = { + override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -122,7 +122,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(row.getInt(ordinal)) } - def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Int = { buffer.getInt() } @@ -134,7 +134,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setInt(toOrdinal, from.getInt(fromOrdinal)) @@ -150,7 +150,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(row.getLong(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Long = { buffer.getLong() } @@ -162,7 +162,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setLong(toOrdinal, from.getLong(fromOrdinal)) @@ -178,7 +178,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(row.getFloat(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Float = { buffer.getFloat() } @@ -190,7 +190,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) @@ -206,7 +206,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(row.getDouble(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Double = { buffer.getDouble() } @@ -218,7 +218,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) @@ -234,7 +234,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } - override def extract(buffer: ByteBuffer) = buffer.get() == 1 + override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { row.setBoolean(ordinal, buffer.get() == 1) @@ -244,7 +244,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) @@ -260,7 +260,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(row.getByte(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Byte = { buffer.get() } @@ -272,7 +272,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setByte(toOrdinal, from.getByte(fromOrdinal)) @@ -288,7 +288,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(row.getShort(ordinal)) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Short = { buffer.getShort() } @@ -300,7 +300,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to.setShort(toOrdinal, from.getShort(fromOrdinal)) @@ -312,50 +312,51 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } -private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { - override def extract(buffer: ByteBuffer) = { - val date = new Date(buffer.getLong()) - date +private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { + override def extract(buffer: ByteBuffer): Int = { + buffer.getInt } - override def append(v: Date, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime) + override def append(v: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(v) } - override def getField(row: Row, ordinal: Int) = { - row(ordinal).asInstanceOf[Date] + override def getField(row: Row, ordinal: Int): Int = { + row(ordinal).asInstanceOf[Int] } - override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row(ordinal) = value } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Timestamp = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) timestamp @@ -365,7 +366,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { buffer.putLong(v.getTime).putInt(v.getNanos) } - override def getField(row: Row, ordinal: Int) = { + override def getField(row: Row, ordinal: Int): Timestamp = { row(ordinal).asInstanceOf[Timestamp] } @@ -374,12 +375,39 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { } } +private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) + extends NativeColumnType( + DecimalType(Some(PrecisionInfo(precision, scale))), + 10, + FIXED_DECIMAL.defaultSize) { + + override def extract(buffer: ByteBuffer): Decimal = { + Decimal(buffer.getLong(), precision, scale) + } + + override def append(v: Decimal, buffer: ByteBuffer): Unit = { + buffer.putLong(v.toUnscaledLong) + } + + override def getField(row: Row, ordinal: Int): Decimal = { + row(ordinal).asInstanceOf[Decimal] + } + + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row(ordinal) = value + } +} + +private[sql] object FIXED_DECIMAL { + val defaultSize = 8 +} + private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( typeId: Int, defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int) = { + override def actualSize(row: Row, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -387,7 +415,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( buffer.putInt(v.length).put(v, 0, v.length) } - override def extract(buffer: ByteBuffer) = { + override def extract(buffer: ByteBuffer): Array[Byte] = { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) @@ -395,40 +423,46 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) { +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } - override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]] + override def getField(row: Row, ordinal: Int): Array[Byte] = { + row(ordinal).asInstanceOf[Array[Byte]] + } } // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. -private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) { +private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal)) + override def getField(row: Row, ordinal: Int): Array[Byte] = { + SparkSqlSerializer.serialize(row(ordinal)) + } } private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { - case IntegerType => INT - case LongType => LONG - case FloatType => FLOAT - case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT - case StringType => STRING - case BinaryType => BINARY - case DateType => DATE + case IntegerType => INT + case LongType => LONG + case FloatType => FLOAT + case DoubleType => DOUBLE + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT + case StringType => STRING + case BinaryType => BINARY + case DateType => DATE case TimestampType => TIMESTAMP - case _ => GENERIC + case DecimalType.Fixed(precision, scale) if precision < 19 => + FIXED_DECIMAL(precision, scale) + case _ => GENERIC } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 11d5943fb427..d9b6fb43ab83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.{Accumulable, Accumulator, Accumulators} +import org.apache.spark.sql.catalyst.expressions + import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -50,11 +55,16 @@ private[sql] case class InMemoryRelation( child: SparkPlan, tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null) + private var _statistics: Statistics = null, + private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats = - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + if (_batchStats == null) { + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + } else { + _batchStats + } val partitionStatistics = new PartitionStatistics(output) @@ -77,20 +87,23 @@ private[sql] case class InMemoryRelation( _statistics } - override def statistics = if (_statistics == null) { - if (batchStats.value.isEmpty) { - // Underlying columnar RDD hasn't been materialized, no useful statistics information - // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + override def statistics: Statistics = { + if (_statistics == null) { + if (batchStats.value.isEmpty) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + } else { + // Underlying columnar RDD has been materialized, required information has also been + // collected via the `batchStats` accumulator, compute the final statistics, + // and update `_statistics`. + _statistics = Statistics(sizeInBytes = computeSizeInBytes) + _statistics + } } else { - // Underlying columnar RDD has been materialized, required information has also been collected - // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. - _statistics = Statistics(sizeInBytes = computeSizeInBytes) + // Pre-computed statistics _statistics } - } else { - // Pre-computed statistics - _statistics } // If the cached column buffers were not passed in, we calculate them in the constructor. @@ -99,7 +112,7 @@ private[sql] case class InMemoryRelation( buildBuffers() } - def recache() = { + def recache(): Unit = { _cachedColumnBuffers.unpersist() _cachedColumnBuffers = null buildBuffers() @@ -109,16 +122,27 @@ private[sql] case class InMemoryRelation( val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { - def next() = { + def next(): CachedBatch = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) val initialBufferSize = columnType.defaultSize * batchSize - ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) + ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression) }.toArray var rowCount = 0 while (rowIterator.hasNext && rowCount < batchSize) { val row = rowIterator.next() + + // Added for SPARK-6082. This assertion can be useful for scenarios when something + // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM + // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat + // hard to decipher. + assert( + row.size == columnBuilders.size, + s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. + |Row content: $row + """.stripMargin) + var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -133,7 +157,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build().array()), stats) } - def hasNext = rowIterator.hasNext + def hasNext: Boolean = rowIterator.hasNext } }.persist(storageLevel) @@ -144,12 +168,12 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated) + _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance() = { + override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), useCompression, @@ -158,13 +182,20 @@ private[sql] case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - statisticsToBePropagated).asInstanceOf[this.type] + statisticsToBePropagated, + batchStats).asInstanceOf[this.type] } - def cachedColumnBuffers = _cachedColumnBuffers + def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated) + Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) + + private[sql] def uncache(blocking: Boolean): Unit = { + Accumulators.remove(batchStats.id) + cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } } private[sql] case class InMemoryColumnarTableScan( @@ -209,7 +240,7 @@ private[sql] case class InMemoryColumnarTableScan( case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } - val partitionFilters = { + val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = @@ -227,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan( } } + lazy val enableAccumulators: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + // Accumulators used for testing purposes - val readPartitions = sparkContext.accumulator(0) - val readBatches = sparkContext.accumulator(0) + lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - override def execute() = { - readPartitions.setValue(0) - readBatches.setValue(0) + override def execute(): RDD[Row] = { + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( @@ -260,17 +296,19 @@ private[sql] case class InMemoryColumnarTableScan( val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = requestedColumnIndices.map { batch => - ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + val columnAccessors = requestedColumnIndices.map { batchColumnIndex => + ColumnAccessor( + relation.output(batchColumnIndex).dataType, + ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } // Extract rows via column accessors new Iterator[Row] { private[this] val rowLen = nextRow.length - override def next() = { + override def next(): Row = { var i = 0 while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) @@ -279,11 +317,11 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors(0).hasNext + override def hasNext: Boolean = columnAccessors(0).hasNext } } - if (rows.hasNext) { + if (rows.hasNext && enableAccumulators) { readPartitions += 1 } @@ -295,14 +333,16 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema + def statsString: String = relation.partitionStatistics.schema .zip(cachedBatch.stats.toSeq) .map { case (a, s) => s"${a.name}: $s" } .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { - readBatches += 1 + if (enableAccumulators) { + readBatches += 1 + } true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 965782a40031..4d35650d4b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -55,5 +55,5 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { pos += 1 } - abstract override def hasNext = seenNulls < nullCount || super.hasNext + abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index 7dff9deac8dc..cb205defbb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -19,19 +19,19 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { +private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ - abstract override protected def initialize() = { + abstract override protected def initialize(): Unit = { super.initialize() decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) } - abstract override def hasNext = super.hasNext || decoder.hasNext + abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext override def extractSingle(row: MutableRow, ordinal: Int): Unit = { decoder.next(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index aead768ecdf0..8e2a1af6dae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType /** * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.NativeType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: NativeType] +private[sql] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => @@ -81,7 +81,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] } } - override def build() = { + override def build(): ByteBuffer = { val nonNullBuffer = buildNonNulls() val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 879d29bcfa6f..17c2d9b11118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -22,9 +22,9 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: NativeType] { +private[sql] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -38,7 +38,7 @@ private[sql] trait Encoder[T <: NativeType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] { +private[sql] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean @@ -49,9 +49,9 @@ private[sql] trait CompressionScheme { def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] + def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] - def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] + def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } private[sql] trait WithCompressionSchemes { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 68a5b1de7691..534ae90ddbc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -33,56 +33,58 @@ import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 - override def supports(columnType: ColumnType[_, _]) = true + override def supports(columnType: ColumnType[_, _]): Boolean = true - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { - override def uncompressedSize = 0 + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + override def uncompressedSize: Int = 0 - override def compressedSize = 0 + override def compressedSize: Int = 0 - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { override def next(row: MutableRow, ordinal: Int): Unit = { columnType.extract(buffer, row, ordinal) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType]( + buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 @@ -90,9 +92,9 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = _compressedSize + override def compressedSize: Int = _compressedSize override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) @@ -114,7 +116,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { @@ -150,12 +152,12 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private var run = 0 private var valueCount = 0 - private var currentValue: T#JvmType = _ + private var currentValue: T#InternalType = _ override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { @@ -169,7 +171,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { columnType.setField(row, ordinal, currentValue) } - override def hasNext = valueCount < run || buffer.hasRemaining + override def hasNext: Boolean = valueCount < run || buffer.hasRemaining } } @@ -179,20 +181,21 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // 32K unique values allowed val MAX_DICT_SIZE = Short.MaxValue - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : Decoder[T] = { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def supports(columnType: ColumnType[_, _]) = columnType match { + override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { case INT | LONG | STRING => true case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -205,7 +208,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { private var count = 0 // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. - private var values = new mutable.ArrayBuffer[T#JvmType](1024) + private var values = new mutable.ArrayBuffer[T#InternalType](1024) // The dictionary that maps a value to the encoded short integer. private val dictionary = mutable.HashMap.empty[Any, Short] @@ -237,7 +240,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -260,19 +263,19 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2 + override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private val dictionary = { // TODO Can we clean up this mess? Maybe move this to `DataType`? implicit val classTag = { val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) + ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) } Array.fill(buffer.getInt()) { @@ -284,7 +287,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { columnType.setField(row, ordinal, dictionary(buffer.getShort())) } - override def hasNext = buffer.hasRemaining + override def hasNext: Boolean = buffer.hasRemaining } } @@ -293,15 +296,16 @@ private[sql] case object BooleanBitSet extends CompressionScheme { val BITS_PER_LONG = 64 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 @@ -310,7 +314,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { _uncompressedSize += BOOLEAN.defaultSize } - override def compress(from: ByteBuffer, to: ByteBuffer) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -347,9 +351,9 @@ private[sql] case object BooleanBitSet extends CompressionScheme { to } - override def uncompressedSize = _uncompressedSize + override def uncompressedSize: Int = _uncompressedSize - override def compressedSize = { + override def compressedSize: Int = { val extra = if (_uncompressedSize % BITS_PER_LONG == 0) 0 else 1 (_uncompressedSize / BITS_PER_LONG + extra) * 8 + 4 } @@ -380,22 +384,23 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private[sql] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT class Encoder extends compression.Encoder[IntegerType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Int = _ @@ -459,22 +464,23 @@ private[sql] case object IntDelta extends CompressionScheme { private[sql] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + : compression.Decoder[T] = { new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG class Encoder extends compression.Encoder[LongType.type] { protected var _compressedSize: Int = 0 protected var _uncompressedSize: Int = 0 - override def compressedSize = _compressedSize - override def uncompressedSize = _uncompressedSize + override def compressedSize: Int = _compressedSize + override def uncompressedSize: Int = _uncompressedSize private var prevValue: Long = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index be9f155253d7..18b1ba4c5c4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -45,7 +46,7 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil } else { @@ -55,12 +56,9 @@ case class Aggregate( ClusteredDistribution(groupingExpressions) :: Nil } } + } - // HACK: Generators don't correctly preserve their output through serializations so we grab - // out child's output attributes statically here. - private[this] val childOutput = child.output - - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) /** * An aggregate that needs to be computed for each row in a group. @@ -81,7 +79,7 @@ case class Aggregate( case a: AggregateExpression => ComputedAggregate( a, - BindReferences.bindReference(a, childOutput), + BindReferences.bindReference(a, child.output), AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) } }.toArray @@ -123,7 +121,7 @@ case class Aggregate( } } - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() @@ -150,7 +148,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: Row = null while (iter.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7c0b72aab448..5b2e46962cd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,24 +19,45 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} -import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner} +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair +object Exchange { + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. + */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + newOrdering: Seq[SortOrder], + child: SparkPlan) + extends UnaryNode { - override def outputPartitioning = newPartitioning + override def outputPartitioning: Partitioning = newPartitioning - override def output = child.output + override def outputOrdering: Seq[SortOrder] = newOrdering + + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] @@ -44,7 +65,62 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - override def execute() = attachTree(this , "execute") { + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf + + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out + // through write(key) and then write(value) instead of write((key, value)). Because + // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use + // it when spillToMergeableFile in ExternalSorter will be used. + // So, we will not use SparkSqlSerializer2 when + // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater + // then the bypassMergeThreshold; or + // - newOrdering is defined. + val cannotUseSqlSerializer2 = + (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + + // It is true when there is no field that needs to be write out. + // For now, we will not use SparkSqlSerializer2 when noField is true. + val noField = + (keySchema == null || keySchema.length == 0) && + (valueSchema == null || valueSchema.length == 0) + + val useSqlSerializer2 = + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + !noField + + val serializer = if (useSqlSerializer2) { + logInfo("Using SparkSqlSerializer2.") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Using SparkSqlSerializer.") + new SparkSqlSerializer(sparkConf) + } + + serializer + } + + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -55,7 +131,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -68,12 +146,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => @@ -86,8 +172,14 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } + val keySchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) shuffled.map(_._1) @@ -106,7 +198,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -119,27 +212,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. */ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions = sqlContext.conf.numShufflePartitions + def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. - def meetsRequirements = - !operator.requiredChildDistribution.zip(operator.children).map { + def meetsRequirements: Boolean = + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child.outputOrdering + } + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -147,7 +247,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl // datasets are both clustered by "a", but these two outputPartitionings are not // compatible. // TODO: ASSUMES TRANSITIVITY? - def compatible = + def compatible: Boolean = !operator.children .map(_.outputPartitioning) .sliding(2) @@ -156,28 +256,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( + partitioning: Partitioning, + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) + } else { + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } + + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } + } else { + withShuffle + } + + withSort + } + } - if (meetsRequirements && compatible) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. - val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 20b14834bb0d..6fd93c4dc6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -37,13 +37,42 @@ object RDDConversions { Iterator.empty } else { val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) + val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 + } + + mutableRow + } + } + } + } + + /** + * Convert the objects inside Row into the types Catalyst expected. + */ + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) + mutableRow(i) = converters(i)(r(i)) i += 1 } @@ -54,59 +83,49 @@ object RDDConversions { } } +/** Logical plan node for scanning data from an RDD. */ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { + extends LogicalPlan with MultiInstanceRelation { // wf: here should be leaf node - def children = Nil + override def children: Seq[LogicalPlan] = Nil - def newInstance() = + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { + override def sameResult(plan: LogicalPlan): Boolean = plan match { case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) ) } +/** Physical plan node for scanning data from an RDD. */ case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - -@deprecated("Use LogicalRDD", "1.2.0") -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd + override def execute(): RDD[Row] = rdd } -@deprecated("Use LogicalRDD", "1.2.0") -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { +/** Logical plan node for scanning data from a local collection. */ +case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { - def output = alreadyPlanned.output - override def children = Nil + override def children: Seq[LogicalPlan] = Nil - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } + override def newInstance(): this.type = + LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan) = plan match { - case SparkLogicalPlan(ExistingRdd(_, rdd)) => - rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LogicalRDD(_, otherRDD) => rows == rows case _ => false } - @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) + @transient override lazy val statistics: Statistics = Statistics( + // TODO: Improve the statistics estimation. + // This is made small enough so it can be broadcasted. + sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 95172420608f..575849481faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -42,7 +43,7 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def execute() = attachTree(this, "execute") { + override def execute(): RDD[Row] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // TODO Move out projection objects creation and transfer to // workers via closure. However we can't assume the Projection @@ -55,7 +56,7 @@ case class Expand( private[this] var idx = -1 // -1 means the initial state private[this] var input: Row = _ - override final def hasNext = (-1 < idx && idx < groups.length) || iter.hasNext + override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext override final def next(): Row = { if (idx <= 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 38877c28de3a..5201e20a1056 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ /** @@ -26,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. + * @param output the output attributes of this node, which constructed in analysis phase, + * and we can not change it, as the parent node bound with it already. */ @DeveloperApi case class Generate( generator: Generator, join: Boolean, outer: Boolean, + output: Seq[Attribute], child: SparkPlan) extends UnaryNode { - // This must be a val since the generator output expr ids are not preserved by serialization. - protected val generatorOutput: Seq[Attribute] = { - if (join && outer) { - generator.output.map(_.withNullability(true)) - } else { - generator.output - } - } - - // This must be a val since the generator output expr ids are not preserved by serialization. - override val output = - if (join) child.output ++ generatorOutput else generatorOutput - val boundGenerator = BindReferences.bindReference(generator, child.output) - override def execute() = { + override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.output.size)(Literal(null)) + val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = newProjection(child.output ++ nullValues, child.output) - val joinProjection = - newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) + val joinProjection = newProjection(output, output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 4abe26fe4afc..b1ef6556de1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -49,7 +50,7 @@ case class GeneratedAggregate( child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) { UnspecifiedDistribution :: Nil } else { @@ -60,13 +61,15 @@ case class GeneratedAggregate( } } - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - override def execute() = { + override def execute(): RDD[Row] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } + // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite + // (in test "aggregation with codegen"). val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -92,13 +95,16 @@ case class GeneratedAggregate( } val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal(null, calcType) + val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) val updateFunction = Coalesce( - Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType) + ) :: currentSum :: zero :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -108,8 +114,8 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => - val calcType = + case cs @ CombineSum(expr) => + val calcType = expr.dataType expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -117,45 +123,39 @@ case class GeneratedAggregate( expr.dataType } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val currentSum = AttributeReference("currentSum", calcType, nullable = false)() - val initialCount = Literal(0L) - val initialSum = Cast(Literal(0L), calcType) + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal.create(null, calcType) + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { + val actualExpr = expr match { case UnscaledValue(e) => e case _ => expr } - - val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil) - + // partial sum result can be null only when no input rows present + val updateFunction = If( + IsNotNull(actualExpr), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType)) :: currentSum :: zero :: Nil), + currentSum) + val result = expr.dataType match { case DecimalType.Fixed(_, _) => - If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), - Cast(Divide( - Cast(currentSum, DecimalType.Unlimited), - Cast(currentCount, DecimalType.Unlimited)), a.dataType)) - case _ => - If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), - Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) + Cast(currentSum, cs.dataType) + case _ => currentSum } - AggregateEvaluation( - currentCount :: currentSum :: Nil, - initialCount :: initialSum :: Nil, - updateCount :: updateSum :: Nil, - result - ) - + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal(null, expr.dataType) + val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) AggregateEvaluation( @@ -164,8 +164,20 @@ case class GeneratedAggregate( updateMax :: Nil, currentMax) + case m @ Min(expr) => + val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() + val initialValue = Literal.create(null, expr.dataType) + val updateMin = MinOf(currentMin, expr) + + AggregateEvaluation( + currentMin :: Nil, + initialValue :: Nil, + updateMin :: Nil, + currentMin) + case CollectHashSet(Seq(expr)) => - val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() + val set = + AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() val initialValue = NewSet(expr.dataType) val addToSet = AddItemToSet(expr, set) @@ -176,9 +188,10 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => - val ArrayType(inputType, _) = inputSet.dataType - val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(inputType) + val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType + val set = + AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() + val initialValue = NewSet(elementType) val collectSets = CombineSets(set, inputSet) AggregateEvaluation( @@ -186,6 +199,8 @@ case class GeneratedAggregate( initialValue :: Nil, collectSets :: Nil, CountSet(set)) + + case o => sys.error(s"$o can't be codegened.") } val computationSchema = computeFunctions.flatMap(_.schema) @@ -271,9 +286,9 @@ case class GeneratedAggregate( private[this] val resultIterator = buffers.entrySet.iterator() private[this] val resultProjection = resultProjectionBuilder() - def hasNext = resultIterator.hasNext + def hasNext: Boolean = resultIterator.hasNext - def next() = { + def next(): Row = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala new file mode 100644 index 000000000000..8a8c3a404323 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.Attribute + + +/** + * Physical plan node for scanning data from a local collection. + */ +case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { + + private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + + override def execute(): RDD[Row] = rdd + + + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).toArray + } + + + override def executeTake(limit: Int): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 6fecd1ff066c..59c89800da00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,12 +21,14 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{ScalaReflection, trees} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import scala.collection.mutable.ArrayBuffer + object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() } @@ -65,10 +67,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + /** Specifies any partition requirements on the input data for this operator. */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. */ @@ -77,15 +86,65 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = - execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + + def executeCollect(): Array[Row] = { + execute().mapPartitions { iter => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + iter.map(converter(_).asInstanceOf[Row]) + }.collect() + } + + /** + * Runs this query returning the first `n` rows as an array. + * + * This is modeled after RDD.take but never runs any job locally on the driver. + */ + def executeTake(n: Int): Array[Row] = { + if (n == 0) { + return new Array[Row](0) + } + + val childRDD = execute().map(_.copy()) + + val buf = new ArrayBuffer[Row] + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (buf.size < n && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (buf.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * n * partsScanned / buf.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = n - buf.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val sc = sqlContext.sparkContext + val res = + sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + + res.foreach(buf ++= _.take(n - buf.size)) + partsScanned += numPartsToTry + } + + val converter = CatalystTypeConverters.createToScalaConverter(schema) + buf.toArray.map(converter(_).asInstanceOf[Row]) + } protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection(expressions, inputSchema) + GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) } @@ -97,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) } @@ -107,15 +166,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { if (codegenEnabled) { - GeneratePredicate(expression, inputSchema) + GeneratePredicate.generate(expression, inputSchema) } else { - InterpretedPredicate(expression, inputSchema) + InterpretedPredicate.create(expression, inputSchema) } } protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { if (codegenEnabled) { - GenerateOrdering(order, inputSchema) + GenerateOrdering.generate(order, inputSchema) } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 30564e14fa89..eea15aff5dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.types.Decimal @@ -26,20 +27,19 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.{AllScalaRegistrar, ResourcePool} +import com.twitter.chill.ResourcePool import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -55,10 +55,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) kryo.register(classOf[Decimal]) + kryo.register(classOf[JavaHashMap[_, _]]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } @@ -66,15 +65,12 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co private[execution] class KryoResourcePool(size: Int) extends ResourcePool[SerializerInstance](size) { - val ser: KryoSerializer = { + val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. - new KryoSerializer(sparkConf) + new SparkSqlSerializer(sparkConf) } - def newInstance() = ser.newInstance() + def newInstance(): SerializerInstance = ser.newInstance() } private[sql] object SparkSqlSerializer { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 000000000000..cec97de2cd8e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in + * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the + * [[Product2]] are constructed based on their schemata. + * The benefit of this serialization stream is that compared with general-purpose serializers like + * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower + * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: + * 1. It does not support complex types, i.e. Map, Array, and Struct. + * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when + * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because + * the objects passed in the serializer are not in the type of [[Product2]]. Also also see + * the comment of the `serializer` method in [[Exchange]] for more information on it. + */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The corresponding deserialization stream for [[Serializer2SerializationStream]]. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. + */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case decimal: DecimalType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + out.writeInt(bytes.length) + out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getAs[Int](i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[UTF8String](i).getBytes + out.writeInt(bytes.length) + out.write(bytes) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. + */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. + + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case decimal: DecimalType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // First, read in the unscaled value. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, UTF8String(bytes)) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0cc9d049c964..030ef118f75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources.{CreateTempTableUsing, CreateTableUsing} - +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -154,15 +162,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false case _ => true } - def allAggregates(exprs: Seq[Expression]) = + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = exprs.flatMap(_.collect { case a: AggregateExpression => a }) } @@ -211,9 +219,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil - case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: ParquetRelation, partition, child, overwrite, ifNotExists) => InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => + val partitionColNames = relation.partitioningAttributes.map(_.name).toSet + val filtersToPush = filters.filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } val prunePushedDownFilters = if (sqlContext.conf.parquetFilterPushDown) { (predicates: Seq[Expression]) => { @@ -225,6 +239,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // "A AND B" in the higher-level filter, not just "B". predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { case (predicate, None) => predicate + // Filter needs to be applied above when it contains partitioning + // columns + case (predicate, _) if(!predicate.references.map(_.name).toSet + .intersect (partitionColNames).isEmpty) => predicate } } } else { @@ -237,7 +255,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetTableScan( _, relation, - if (sqlContext.conf.parquetFilterPushDown) filters else Nil)) :: Nil + if (sqlContext.conf.parquetFilterPushDown) filtersToPush else Nil)) :: Nil case _ => Nil } @@ -257,7 +275,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { - def numPartitions = self.numPartitions + def numPartitions: Int = self.numPartitions def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -284,13 +302,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil - case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - val nPartitions = if (data.isEmpty) 1 else numPartitions - PhysicalRDD( - output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), - StructType.fromAttributes(output))) :: Nil + LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -299,12 +312,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil - case logical.Generate(generator, join, outer, _, child) => - execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil - case logical.NoRelation => + case g @ logical.Generate(generator, join, outer, _, _, child) => + execution.Generate( + generator, join = join, outer = outer, g.output, planLater(child)) :: Nil + case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil @@ -314,13 +329,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, options) => + case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( - CreateTempTableUsing(tableName, userSpecifiedSchema, provider, options)) :: Nil - - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, options) => + CreateTempTableUsing( + tableName, userSpecifiedSchema, provider, opts)) :: Nil + case c: CreateTableUsing if !c.temporary => + sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + case c: CreateTableUsing if c.temporary && c.allowExisting => + sys.error("allowExisting should be set to false when creating a temporary table.") + + case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) => + val cmd = + CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query) + ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + case LogicalDescribeCommand(table, isExtended) => + val resultPlan = self.sqlContext.executePlan(table).executedPlan + ExecutedCommand( + RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 16ca4be5587c..d286fe81bee5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution} -import org.apache.spark.util.MutablePair +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter /** @@ -36,14 +33,16 @@ import org.apache.spark.util.collection.ExternalSorter */ @DeveloperApi case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => val resuableProjection = buildProjection() iter.map(resuableProjection) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -51,13 +50,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ @DeveloperApi case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output) - def execute() = child.execute().mapPartitions { iter => + override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -67,10 +68,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output // TODO: How to pick seed? - override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + } } /** @@ -79,8 +82,8 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: @DeveloperApi case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes - override def output = children.head.output - override def execute() = sparkContext.union(children.map(_.execute())) + override def output: Seq[Attribute] = children.head.output + override def execute(): RDD[Row] = sparkContext.union(children.map(_.execute())) } /** @@ -100,54 +103,12 @@ case class Limit(limit: Int, child: SparkPlan) /** We must copy rows when sort based shuffle is on */ private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - override def output = child.output - override def outputPartitioning = SinglePartition - - /** - * A custom implementation modeled after the take function on RDDs but which never runs any job - * locally. This is to avoid shipping an entire partition of data in order to retrieve only a few - * rows. - */ - override def executeCollect(): Array[Row] = { - if (limit == 0) { - return new Array[Row](0) - } - - val childRDD = child.execute().map(_.copy()) - - val buf = new ArrayBuffer[Row] - val totalParts = childRDD.partitions.length - var partsScanned = 0 - while (buf.size < limit && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 - if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. - if (buf.size == 0) { - numPartsToTry = totalParts - 1 - } else { - numPartsToTry = (1.5 * limit * partsScanned / buf.size).toInt - } - } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions - - val left = limit - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition - res.foreach(buf ++= _.take(limit - buf.size)) - partsScanned += numPartsToTry - } + override def executeCollect(): Array[Row] = child.executeTake(limit) - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) - } - - override def execute() = { + override def execute(): RDD[Row] = { val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.take(limit).map(row => (false, row.copy())) @@ -160,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan) } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } } @@ -174,18 +135,24 @@ case class Limit(limit: Int, child: SparkPlan) @DeveloperApi case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output = child.output - override def outputPartitioning = SinglePartition + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = SinglePartition - val ord = new RowOrdering(sortOrder, child.output) + private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala(_, this.schema)) + private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) + + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + collectData().map(converter(_).asInstanceOf[Row]) + } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sparkContext.makeRDD(executeCollect(), 1) + override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -200,17 +167,19 @@ case class Sort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) iterator.map(_.copy()).toArray.sorted(ordering).iterator }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -225,19 +194,24 @@ case class ExternalSort( global: Boolean, child: SparkPlan) extends UnaryNode { - override def requiredChildDistribution = + + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - override def execute() = attachTree(this, "sort") { + override def execute(): RDD[Row] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r, null))) - sorter.iterator.map(_._1) + sorter.insertAll(iterator.map(r => (r.copy, null))) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } - override def output = child.output + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -249,12 +223,12 @@ case class ExternalSort( */ @DeveloperApi case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val hashSet = new scala.collection.mutable.HashSet[Row]() @@ -279,9 +253,9 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { */ @DeveloperApi case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } } @@ -293,9 +267,9 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = children.head.output + override def output: Seq[Attribute] = children.head.output - override def execute() = { + override def execute(): RDD[Row] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } @@ -308,6 +282,7 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { */ @DeveloperApi case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = child.execute() + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = child.execute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 52a31f01a435..99f24910fd61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -52,13 +54,19 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { */ protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) - override def output = cmd.output + override def output: Seq[Attribute] = cmd.output - override def children = Nil + override def children: Seq[SparkPlan] = Nil override def executeCollect(): Array[Row] = sideEffectResult.toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray + + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** @@ -67,9 +75,10 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { @DeveloperApi case class SetCommand( kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) extends RunnableCommand with Logging { + override val output: Seq[Attribute]) + extends RunnableCommand with Logging { - override def run(sqlContext: SQLContext) = kv match { + override def run(sqlContext: SQLContext): Seq[Row] = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => logWarning( @@ -113,10 +122,13 @@ case class SetCommand( @DeveloperApi case class ExplainCommand( logicalPlan: LogicalPlan, - override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand { + override val output: Seq[Attribute] = + Seq(AttributeReference("plan", StringType, nullable = false)()), + extended: Boolean = false) + extends RunnableCommand { // Run through the optimizer to generate the physical plan. - override def run(sqlContext: SQLContext) = try { + override def run(sqlContext: SQLContext): Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = sqlContext.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString @@ -134,10 +146,13 @@ case class ExplainCommand( case class CacheTableCommand( tableName: String, plan: Option[LogicalPlan], - isLazy: Boolean) extends RunnableCommand { + isLazy: Boolean) + extends RunnableCommand { - override def run(sqlContext: SQLContext) = { - plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + override def run(sqlContext: SQLContext): Seq[Row] = { + plan.foreach { logicalPlan => + sqlContext.registerDataFrameAsTable(DataFrame(sqlContext, logicalPlan), tableName) + } sqlContext.cacheTable(tableName) if (!isLazy) { @@ -158,8 +173,23 @@ case class CacheTableCommand( @DeveloperApi case class UncacheTableCommand(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { - sqlContext.table(tableName).unpersist() + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.table(tableName).unpersist(blocking = false) + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty +} + +/** + * :: DeveloperApi :: + * Clear all cached data from the in-memory cache. + */ +@DeveloperApi +case object ClearCacheCommand extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.clearCache() Seq.empty[Row] } @@ -172,9 +202,47 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand { @DeveloperApi case class DescribeCommand( child: SparkPlan, - override val output: Seq[Attribute]) extends RunnableCommand { + override val output: Seq[Attribute], + isExtended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + child.schema.fields.map { field => + val cmtKey = "comment" + val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" + Row(field.name, field.dataType.simpleString, comment) + } + } +} + +/** + * A command for users to get tables in the given database. + * If a databaseName is not given, the current database will be used. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TABLES [IN databaseName] + * }}} + * :: DeveloperApi :: + */ +@DeveloperApi +case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { + + // The result of SHOW TABLES has two columns, tableName and isTemporary. + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we need to return a Seq of rows, we will call getTables directly + // instead of calling tables in sqlContext. + val rows = sqlContext.catalog.getTables(databaseName).map { + case (tableName, isTemporary) => Row(tableName, isTemporary) + } - override def run(sqlContext: SQLContext) = { - child.output.map(field => Row(field.name, field.dataType.toString, null)) + rows } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 4d7e338e8ed1..710787096e6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute + import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} +import org.apache.spark.{AccumulatorParam, Accumulator} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -32,17 +34,28 @@ import org.apache.spark.sql.types._ * * Usage: * {{{ - * sql("SELECT key FROM src").debug + * import org.apache.spark.sql.execution.debug._ + * sql("SELECT key FROM src").debug() + * dataFrame.typeCheck() * }}} */ package object debug { + /** + * Augments [[SQLContext]] with debug methods. + */ + implicit class DebugSQLContext(sqlContext: SQLContext) { + def debug(): Unit = { + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + } + } + /** * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: SchemaRDD) { + implicit class DebugQuery(query: DataFrame) { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -77,7 +90,7 @@ package object debug { } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def output = child.output + def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { def zero(initialValue: HashSet[String]): HashSet[String] = { @@ -98,10 +111,10 @@ package object debug { */ case class ColumnMetrics( elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) - val tupleCount = sparkContext.accumulator[Int](0) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) - val numColumns = child.output.size - val columnStats = Array.fill(child.output.size)(new ColumnMetrics()) + val numColumns: Int = child.output.size + val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { println(s"== ${child.simpleString} ==") @@ -112,11 +125,11 @@ package object debug { } } - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { + def hasNext: Boolean = iter.hasNext + def next(): Row = { val currentRow = iter.next() tupleCount += 1 var i = 0 @@ -135,11 +148,9 @@ package object debug { } /** - * :: DeveloperApi :: * Helper functions for checking that runtime types match a given schema. */ - @DeveloperApi - object TypeCheck { + private[sql] object TypeCheck { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => @@ -153,37 +164,36 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => - case (_: String, StringType) => + case (_: UTF8String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") } } /** - * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. */ - @DeveloperApi private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { import TypeCheck._ - override def nodeName = "" + override def nodeName: String = "" /* Only required when defining this class in a REPL. override def makeCopy(args: Array[Object]): this.type = TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] */ - def output = child.output + def output: Seq[Attribute] = child.output - def children = child :: Nil + def children: List[SparkPlan] = child :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().map { row => try typeCheck(row, child.schema) catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 2dd22c020ef1..926f5e6c137e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.rdd.RDD + import scala.concurrent._ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -42,7 +44,7 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - val timeout = { + val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { Duration.Inf @@ -53,7 +55,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient @@ -64,7 +66,7 @@ case class BroadcastHashJoin( sparkContext.broadcast(hashed) } - override def execute() = { + override def execute(): RDD[Row] = { val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2ab064fd0151..3ef1e0d7fbdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -34,11 +34,11 @@ case class BroadcastLeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 36aad13778bd..56200f6b8c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} @@ -44,7 +45,7 @@ case class BroadcastNestedLoopJoin( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -58,12 +59,12 @@ case class BroadcastNestedLoopJoin( } @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 76c14c02aab3..1cbc98354d67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -26,9 +28,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} */ @DeveloperApi case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output - override def execute() = { + override def execute(): RDD[Row] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4012d757d5f9..851de1685509 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -41,7 +41,7 @@ trait HashJoin { case BuildRight => (rightKeys, leftKeys) } - override def output = left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output @transient protected lazy val buildSideKeyGenerator: Projection = newProjection(buildKeys, buildPlan.output) @@ -65,7 +65,7 @@ trait HashJoin { (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next() = { + override final def next(): Row = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 59ef90427254..a396c0f5d56e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import org.apache.spark.annotation.DeveloperApi @@ -49,10 +51,10 @@ case class HashOuterJoin( case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -78,12 +80,12 @@ case class HashOuterJoin( private[this] def leftOuterIterator( key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = rightIter.collect { - case r if (boundCondition(joinedRow.withRight(r))) => joinedRow.copy + case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -91,19 +93,19 @@ case class HashOuterJoin( } else { joinedRow.withRight(rightNullRow).copy :: Nil } - ) + } ret.iterator } private[this] def rightOuterIterator( key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { - val ret: Iterable[Row] = ( + val ret: Iterable[Row] = { if (!key.anyNull) { val temp = leftIter.collect { - case l if (boundCondition(joinedRow.withLeft(l))) => joinedRow.copy + case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy } - if (temp.size == 0) { + if (temp.size == 0) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -111,7 +113,7 @@ case class HashOuterJoin( } else { joinedRow.withLeft(leftNullRow).copy :: Nil } - ) + } ret.iterator } @@ -130,12 +132,12 @@ case class HashOuterJoin( // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + case (r, idx) if boundCondition(joinedRow.withRight(r)) => matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy - } + joinedRow.copy() + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -143,22 +145,21 @@ case class HashOuterJoin( // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy + joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } + case (r, idx) if !rightMatchedSet.contains(idx) => + joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy + joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy + joinedRow(leftNullRow, r).copy() } } } @@ -182,13 +183,13 @@ case class HashOuterJoin( hashTable } - override def execute() = { + override def execute(): RDD[Row] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) joinType match { - case LeftOuter => { + case LeftOuter => val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) val keyGenerator = newProjection(leftKeys, left.output) leftIter.flatMap( currentRow => { @@ -196,8 +197,8 @@ case class HashOuterJoin( joinedRow.withLeft(currentRow) leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) }) - } - case RightOuter => { + + case RightOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val keyGenerator = newProjection(rightKeys, right.output) rightIter.flatMap ( currentRow => { @@ -205,8 +206,8 @@ case class HashOuterJoin( joinedRow.withRight(currentRow) rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) }) - } - case FullOuter => { + + case FullOuter => val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => @@ -214,7 +215,7 @@ case class HashOuterJoin( leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) } - } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 38b8993b03f8..ab84c123e0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.joins +import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -29,16 +31,43 @@ import org.apache.spark.util.collection.CompactBuffer */ private[joins] sealed trait HashedRelation { def get(key: Row): CompactBuffer[Row] + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ -private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) - extends HashedRelation with Serializable { +private[joins] final class GeneralHashedRelation( + private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization + + override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } - override def get(key: Row) = hashTable.get(key) + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } @@ -46,15 +75,25 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) - extends HashedRelation with Serializable { +private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Externalizable { - override def get(key: Row) = { + def this() = this(null) // Needed for serialization + + override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } def getValue(key: Row): Row = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 60003d1900d8..e06f63f94b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -35,20 +36,21 @@ case class LeftSemiJoinBNL( override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output = left.output + override def output: Seq[Attribute] = left.output /** The Streamed Relation */ - override def left = streamed + override def left: SparkPlan = streamed + /** The Broadcast relation */ - override def right = broadcast + override def right: SparkPlan = broadcast @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) - override def execute() = { + override def execute(): RDD[Row] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index ea7babf3be94..a04f2a63b5a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,14 +35,14 @@ case class LeftSemiJoinHash( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - override val buildSide = BuildRight + override val buildSide: BuildSide = BuildRight - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output = left.output + override def output: Seq[Attribute] = left.output - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 418c1c23e554..a6cd8337c1c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -38,10 +40,10 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = left.outputPartitioning - override def requiredChildDistribution = + override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def execute() = { + override def execute(): RDD[Row] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 000000000000..b5123668ba11 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. + */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + override def execute(): RDD[Row] = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. + */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index b85021acc9d4..7a43bfd8bc8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.rdd.RDD + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -48,11 +50,13 @@ private[spark] case class PythonUDF( dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { - override def toString = s"PythonUDF#$name(${children.mkString(",")})" + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" def nullable: Boolean = true - override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") + override def eval(input: Row): PythonUDF.this.EvaluatedType = { + sys.error("PythonUDFs can not be directly evaluated.") + } } /** @@ -63,7 +67,7 @@ private[spark] case class PythonUDF( * multiple child operators. */ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan) = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case p: EvaluatePython => p @@ -107,7 +111,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan) = + def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) /** @@ -135,6 +139,9 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString + // Pyrolite can handle Timestamp and Decimal case (other, _) => other } @@ -171,7 +178,7 @@ object EvaluatePython { }): Row case (c: java.util.Calendar, DateType) => - new java.sql.Date(c.getTime().getTime()) + DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) @@ -184,8 +191,10 @@ object EvaluatePython { case (c: Int, ShortType) => c.toShort case (c: Long, ShortType) => c.toShort case (c: Long, IntegerType) => c.toInt + case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString + case (c: String, StringType) => UTF8String(c) + case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } @@ -202,7 +211,10 @@ case class EvaluatePython( resultAttribute: AttributeReference) extends logical.UnaryNode { - def output = child.output :+ resultAttribute + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references } /** @@ -213,9 +225,10 @@ case class EvaluatePython( @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) extends SparkPlan { - def children = child :: Nil - def execute() = { + def children: Seq[SparkPlan] = child :: Nil + + def execute(): RDD[Row] = { // TODO: Clean up after ourselves? val childResults = child.execute().map(_.copy()).cache() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala new file mode 100644 index 000000000000..ff91e1d74bc2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Functions available for [[DataFrame]]. + * + * @groupname udf_funcs UDF functions + * @groupname agg_funcs Aggregate functions + * @groupname sort_funcs Sorting functions + * @groupname normal_funcs Non-aggregate functions + * @groupname Ungrouped Support functions for DataFrames. + */ +@Experimental +// scalastyle:off +object functions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Returns a [[Column]] based on the given column name. + * + * @group normal_funcs + */ + def col(colName: String): Column = Column(colName) + + /** + * Returns a [[Column]] based on the given column name. Alias of [[col]]. + * + * @group normal_funcs + */ + def column(colName: String): Column = Column(colName) + + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * + * @group normal_funcs + */ + def lit(literal: Any): Column = { + literal match { + case c: Column => return c + case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) + case _ => // continue + } + + val literalExpr = Literal(literal) + Column(literalExpr) + } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Sort functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns a sort expression based on ascending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def asc(columnName: String): Column = Column(columnName).asc + + /** + * Returns a sort expression based on the descending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def desc(columnName: String): Column = Column(columnName).desc + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Aggregate function: returns the sum of all values in the expression. + * + * @group agg_funcs + */ + def sum(e: Column): Column = Sum(e.expr) + + /** + * Aggregate function: returns the sum of all values in the given column. + * + * @group agg_funcs + */ + def sum(columnName: String): Column = sum(Column(columnName)) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + */ + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + + /** + * Aggregate function: returns the sum of distinct values in the expression. + * + * @group agg_funcs + */ + def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + */ + def count(e: Column): Column = e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } + + /** + * Aggregate function: returns the number of items in a group. + * + * @group agg_funcs + */ + def count(columnName: String): Column = count(Column(columnName)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + */ + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + /** + * Aggregate function: returns the number of distinct items in a group. + * + * @group agg_funcs + */ + @scala.annotation.varargs + def countDistinct(columnName: String, columnNames: String*): Column = + countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName)) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @group agg_funcs + */ + def approxCountDistinct(columnName: String, rsd: Double): Column = { + approxCountDistinct(Column(columnName), rsd) + } + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + */ + def avg(e: Column): Column = Average(e.expr) + + /** + * Aggregate function: returns the average of the values in a group. + * + * @group agg_funcs + */ + def avg(columnName: String): Column = avg(Column(columnName)) + + /** + * Aggregate function: returns the first value in a group. + * + * @group agg_funcs + */ + def first(e: Column): Column = First(e.expr) + + /** + * Aggregate function: returns the first value of a column in a group. + * + * @group agg_funcs + */ + def first(columnName: String): Column = first(Column(columnName)) + + /** + * Aggregate function: returns the last value in a group. + * + * @group agg_funcs + */ + def last(e: Column): Column = Last(e.expr) + + /** + * Aggregate function: returns the last value of the column in a group. + * + * @group agg_funcs + */ + def last(columnName: String): Column = last(Column(columnName)) + + /** + * Aggregate function: returns the minimum value of the expression in a group. + * + * @group agg_funcs + */ + def min(e: Column): Column = Min(e.expr) + + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * @group agg_funcs + */ + def min(columnName: String): Column = min(Column(columnName)) + + /** + * Aggregate function: returns the maximum value of the expression in a group. + * + * @group agg_funcs + */ + def max(e: Column): Column = Max(e.expr) + + /** + * Aggregate function: returns the maximum value of the column in a group. + * + * @group agg_funcs + */ + def max(columnName: String): Column = max(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Non-aggregate functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns the first column that is not null. + * {{{ + * df.select(coalesce(df("a"), df("b"))) + * }}} + * + * @group normal_funcs + */ + @scala.annotation.varargs + def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * // Scala: + * df.select( -df("amount") ) + * + * // Java: + * df.select( negate(df.col("amount")) ); + * }}} + * + * @group normal_funcs + */ + def negate(e: Column): Column = -e + + /** + * Inversion of boolean expression, i.e. NOT. + * {{ + * // Scala: select rows that are not active (isActive === false) + * df.filter( !df("isActive") ) + * + * // Java: + * df.filter( not(df.col("isActive")) ); + * }} + * + * @group normal_funcs + */ + def not(e: Column): Column = !e + + /** + * Converts a string expression to upper case. + * + * @group normal_funcs + */ + def upper(e: Column): Column = Upper(e.expr) + + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + + /** + * Computes the square root of the specified float value. + * + * @group normal_funcs + */ + def sqrt(e: Column): Column = Sqrt(e.expr) + + /** + * Computes the absolutle value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////////////////////////////// + + // scalastyle:off + + /* Use the following code to generate: + (0 to 10).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + println(s""" + /** + * Defines a user-defined function of ${x} arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + }""") + } + + (0 to 10).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } + */ + /** + * Defines a user-defined function of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + /** + * Defines a user-defined function of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * @group udf_funcs + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + // scalastyle:on + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + */ + def callUdf(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala new file mode 100644 index 000000000000..0feabc4282f4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ + +import java.sql.Types + + +/** + * Encapsulates workarounds for the extensions, quirks, and bugs in various + * databases. Lots of databases define types that aren't explicitly supported + * by the JDBC spec. Some JDBC drivers also report inaccurate + * information---for instance, BIT(n>1) being reported as a BIT type is quite + * common, even though BIT in JDBC is meant for single-bit values. Also, there + * does not appear to be a standard name for an unbounded string or binary + * type; we use BLOB and CLOB by default but override with database-specific + * alternatives when these are absent or do not behave correctly. + * + * Currently, the only thing DriverQuirks does is handle type mapping. + * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` + * is used when writing to a JDBC table. If `getCatalystType` returns `null`, + * the default type handling is used for the given JDBC type. Similarly, + * if `getJDBCType` returns `(null, None)`, the default type handling is used + * for the given Catalyst type. + */ +private[sql] abstract class DriverQuirks { + def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType + def getJDBCType(dt: DataType): (String, Option[Int]) +} + +private[sql] object DriverQuirks { + /** + * Fetch the DriverQuirks class corresponding to a given database url. + */ + def get(url: String): DriverQuirks = { + if (url.startsWith("jdbc:mysql")) { + new MySQLQuirks() + } else if (url.startsWith("jdbc:postgresql")) { + new PostgresQuirks() + } else { + new NoQuirks() + } + } +} + +private[sql] class NoQuirks extends DriverQuirks { + def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = + null + def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) +} + +private[sql] class PostgresQuirks extends DriverQuirks { + def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + BinaryType + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + StringType + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + StringType + } else null + } + + def getJDBCType(dt: DataType): (String, Option[Int]) = dt match { + case StringType => ("TEXT", Some(java.sql.Types.CHAR)) + case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY)) + case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN)) + case _ => (null, None) + } +} + +private[sql] class MySQLQuirks extends DriverQuirks { + def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + LongType + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + BooleanType + } else null + } + def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala new file mode 100644 index 000000000000..f32651004212 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.util.Properties + +import org.apache.commons.lang.StringEscapeUtils.escapeSql +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils + +private[sql] object JDBCRDD extends Logging { + /** + * Maps a JDBC type to a Catalyst type. This function is called only when + * the DriverQuirks class corresponding to your database driver returns null. + * + * @param sqlType - A field of java.sql.Types + * @return The Catalyst type corresponding to sqlType. + */ + private def getCatalystType(sqlType: Int): DataType = { + val answer = sqlType match { + case java.sql.Types.ARRAY => null + case java.sql.Types.BIGINT => LongType + case java.sql.Types.BINARY => BinaryType + case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers. + case java.sql.Types.BLOB => BinaryType + case java.sql.Types.BOOLEAN => BooleanType + case java.sql.Types.CHAR => StringType + case java.sql.Types.CLOB => StringType + case java.sql.Types.DATALINK => null + case java.sql.Types.DATE => DateType + case java.sql.Types.DECIMAL => DecimalType.Unlimited + case java.sql.Types.DISTINCT => null + case java.sql.Types.DOUBLE => DoubleType + case java.sql.Types.FLOAT => FloatType + case java.sql.Types.INTEGER => IntegerType + case java.sql.Types.JAVA_OBJECT => null + case java.sql.Types.LONGNVARCHAR => StringType + case java.sql.Types.LONGVARBINARY => BinaryType + case java.sql.Types.LONGVARCHAR => StringType + case java.sql.Types.NCHAR => StringType + case java.sql.Types.NCLOB => StringType + case java.sql.Types.NULL => null + case java.sql.Types.NUMERIC => DecimalType.Unlimited + case java.sql.Types.NVARCHAR => StringType + case java.sql.Types.OTHER => null + case java.sql.Types.REAL => DoubleType + case java.sql.Types.REF => StringType + case java.sql.Types.ROWID => LongType + case java.sql.Types.SMALLINT => IntegerType + case java.sql.Types.SQLXML => StringType + case java.sql.Types.STRUCT => StringType + case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TINYINT => IntegerType + case java.sql.Types.VARBINARY => BinaryType + case java.sql.Types.VARCHAR => StringType + case _ => null + } + + if (answer == null) throw new SQLException("Unsupported type " + sqlType) + answer + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst + * schema. + * + * @param url - The JDBC url to fetch information from. + * @param table - The table name of the desired table. This may also be a + * SQL query wrapped in parentheses. + * + * @return A StructType giving the table's Catalyst schema. + * @throws SQLException if the table specification is garbage. + * @throws SQLException if the table contains an unsupported type. + */ + def resolveTable(url: String, table: String, properties: Properties): StructType = { + val quirks = DriverQuirks.get(url) + val conn: Connection = DriverManager.getConnection(url, properties) + try { + val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() + try { + val rsmd = rs.getMetaData + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnName(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val typeName = rsmd.getColumnTypeName(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val metadata = new MetadataBuilder().putString("name", columnName) + var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) + if (columnType == null) columnType = getCatalystType(dataType) + fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + i = i + 1 + } + return new StructType(fields) + } finally { + rs.close() + } + } finally { + conn.close() + } + + throw new RuntimeException("This line is unreachable.") + } + + /** + * Prune all but the specified columns from the specified Catalyst schema. + * + * @param schema - The Catalyst schema of the master table + * @param columns - The list of desired columns + * + * @return A Catalyst schema corresponding to columns in the given order. + */ + private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { + val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) + new StructType(columns map { name => fieldMap(name) }) + } + + /** + * Given a driver string and an url, return a function that loads the + * specified driver string then returns a connection to the JDBC url. + * getConnector is run on the driver code, while the function it returns + * is run on the executor. + * + * @param driver - The class name of the JDBC driver for the given url. + * @param url - The JDBC url to connect to. + * + * @return A function that loads the driver and connects to the url. + */ + def getConnector(driver: String, url: String, properties: Properties): () => Connection = { + () => { + try { + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) + } catch { + case e: ClassNotFoundException => { + logWarning(s"Couldn't find class $driver", e); + } + } + DriverManager.getConnection(url, properties) + } + } + /** + * Build and return JDBCRDD from the given information. + * + * @param sc - Your SparkContext. + * @param schema - The Catalyst schema of the underlying database table. + * @param driver - The class name of the JDBC driver for the given url. + * @param url - The JDBC url to connect to. + * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. + * @param requiredColumns - The names of the columns to SELECT. + * @param filters - The filters to include in all WHERE clauses. + * @param parts - An array of JDBCPartitions specifying partition ids and + * per-partition WHERE clauses. + * + * @return An RDD representing "SELECT requiredColumns FROM fqTable". + */ + def scanTable( + sc: SparkContext, + schema: StructType, + driver: String, + url: String, + properties: Properties, + fqTable: String, + requiredColumns: Array[String], + filters: Array[Filter], + parts: Array[Partition]): RDD[Row] = { + + val prunedSchema = pruneSchema(schema, requiredColumns) + + return new + JDBCRDD( + sc, + getConnector(driver, url, properties), + prunedSchema, + fqTable, + requiredColumns, + filters, + parts) + } +} + +/** + * An RDD representing a table in a database accessed via JDBC. Both the + * driver code and the workers must be able to access the database; the driver + * needs to fetch the schema while the workers need to fetch the data. + */ +private[sql] class JDBCRDD( + sc: SparkContext, + getConnection: () => Connection, + schema: StructType, + fqTable: String, + columns: Array[String], + filters: Array[Filter], + partitions: Array[Partition]) + extends RDD[Row](sc, Nil) { + + /** + * Retrieve the list of partitions corresponding to this RDD. + */ + override def getPartitions: Array[Partition] = partitions + + /** + * `columns`, but as a String suitable for injection into a SQL query. + */ + private val columnList: String = { + val sb = new StringBuilder() + columns.foreach(x => sb.append(",").append(x)) + if (sb.length == 0) "1" else sb.substring(1) + } + + /** + * Converts value to SQL expression. + */ + private def compileValue(value: Any): Any = value match { + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" + case _ => value + } + + /** + * Turns a single Filter into a String representing a SQL expression. + * Returns null for an unhandled filter. + */ + private def compileFilter(f: Filter): String = f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case _ => null + } + + /** + * `filters`, but as a WHERE clause suitable for injection into a SQL query. + */ + private val filterWhereClause: String = { + val filterStrings = filters map compileFilter filter (_ != null) + if (filterStrings.size > 0) { + val sb = new StringBuilder("WHERE ") + filterStrings.foreach(x => sb.append(x).append(" AND ")) + sb.substring(0, sb.length - 5) + } else "" + } + + /** + * A WHERE clause representing both `filters`, if any, and the current partition. + */ + private def getWhereClause(part: JDBCPartition): String = { + if (part.whereClause != null && filterWhereClause.length > 0) { + filterWhereClause + " AND " + part.whereClause + } else if (part.whereClause != null) { + "WHERE " + part.whereClause + } else { + filterWhereClause + } + } + + // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that + // we don't have to potentially poke around in the Metadata once for every + // row. + // Is there a better way to do this? I'd rather be using a type that + // contains only the tags I define. + abstract class JDBCConversion + case object BooleanConversion extends JDBCConversion + case object DateConversion extends JDBCConversion + case object DecimalConversion extends JDBCConversion + case object DoubleConversion extends JDBCConversion + case object FloatConversion extends JDBCConversion + case object IntegerConversion extends JDBCConversion + case object LongConversion extends JDBCConversion + case object BinaryLongConversion extends JDBCConversion + case object StringConversion extends JDBCConversion + case object TimestampConversion extends JDBCConversion + case object BinaryConversion extends JDBCConversion + + /** + * Maps a StructType to a type tag list. + */ + def getConversions(schema: StructType): Array[JDBCConversion] = { + schema.fields.map(sf => sf.dataType match { + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Unlimited => DecimalConversion + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => + if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case _ => throw new IllegalArgumentException(s"Unsupported field $sf") + }).toArray + } + + + /** + * Runs the SQL query against the JDBC driver. + */ + override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] + { + var closed = false + var finished = false + var gotNext = false + var nextValue: Row = null + + context.addTaskCompletionListener{ context => close() } + val part = thePart.asInstanceOf[JDBCPartition] + val conn = getConnection() + + // H2's JDBC driver does not support the setSchema() method. We pass a + // fully-qualified table name in the SELECT statement. I don't know how to + // talk about a table in a completely portable way. + + val myWhereClause = getWhereClause(part) + + val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" + val stmt = conn.prepareStatement(sqlText, + ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + val rs = stmt.executeQuery() + + val conversions = getConversions(schema) + val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) + + def getNext(): Row = { + if (rs.next()) { + var i = 0 + while (i < conversions.length) { + val pos = i + 1 + conversions(i) match { + case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + case DateConversion => + mutableRow.update(i, DateUtils.fromJavaDate(rs.getDate(pos))) + case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) + case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) + case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) + case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) + case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 + case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) + case BinaryLongConversion => { + val bytes = rs.getBytes(pos) + var ans = 0L + var j = 0 + while (j < bytes.size) { + ans = 256 * ans + (255 & bytes(j)) + j = j + 1; + } + mutableRow.setLong(i, ans) + } + } + if (rs.wasNull) mutableRow.setNullAt(i) + i = i + 1 + } + mutableRow + } else { + finished = true + null.asInstanceOf[Row] + } + } + + def close() { + if (closed) return + try { + if (null != rs) { + rs.close() + } + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt) { + stmt.close() + } + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn) { + conn.close() + } + logInfo("closed connection") + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + close() + } + gotNext = true + } + } + !finished + } + + override def next(): Row = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala new file mode 100644 index 000000000000..5f480083d5a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.DriverManager +import java.util.Properties + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.Partition +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * Data corresponding to one partition of a JDBCRDD. + */ +private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { + override def index: Int = idx +} + +/** + * Instructions on how to partition the table among workers. + */ +private[sql] case class JDBCPartitioningInfo( + column: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int) + +private[sql] object JDBCRelation { + /** + * Given a partitioning schematic (a column of integral type, a number of + * partitions, and upper and lower bounds on the column's value), generate + * WHERE clauses for each partition so that each row in the table appears + * exactly once. The parameters minValue and maxValue are advisory in that + * incorrect values may cause the partitioning to be poor, but no data + * will fail to be represented. + */ + def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) + + val numPartitions = partitioning.numPartitions + val column = partitioning.column + if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) + // Overflow and silliness can happen if you subtract then divide. + // Here we get a little roundoff, but that's (hopefully) OK. + val stride: Long = (partitioning.upperBound / numPartitions + - partitioning.lowerBound / numPartitions) + var i: Int = 0 + var currentValue: Long = partitioning.lowerBound + var ans = new ArrayBuffer[Partition]() + while (i < numPartitions) { + val lowerBound = if (i != 0) s"$column >= $currentValue" else null + currentValue += stride + val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val whereClause = + if (upperBound == null) { + lowerBound + } else if (lowerBound == null) { + upperBound + } else { + s"$lowerBound AND $upperBound" + } + ans += JDBCPartition(whereClause, i) + i = i + 1 + } + ans.toArray + } +} + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} + +private[sql] case class JDBCRelation( + url: String, + table: String, + parts: Array[Partition], + properties: Properties = new Properties())(@transient val sqlContext: SQLContext) + extends BaseRelation + with PrunedFilteredScan { + + override val needConversion: Boolean = false + + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName + JDBCRDD.scanTable( + sqlContext.sparkContext, + schema, + driver, + url, + properties, + table, + requiredColumns, + filters, + parts) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala new file mode 100644 index 000000000000..d4e0abc040bc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Connection, DriverManager, PreparedStatement} + +import org.apache.spark.Logging +import org.apache.spark.sql.types._ + +package object jdbc { + private[sql] object JDBCWriteDetails extends Logging { + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): + PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. + */ + def savePartition(url: String, table: String, iterator: Iterator[Row], + rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = { + val conn = DriverManager.getConnection(url) + var committed = false + try { + conn.setAutoCommit(false) // Everything in the same db transaction. + val stmt = insertStatement(conn, table, rddSchema) + try { + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case DecimalType.Unlimited => stmt.setBigDecimal(i + 1, + row.getAs[java.math.BigDecimal](i)) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.executeUpdate() + } + } finally { + stmt.close() + } + conn.commit() + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + conn.rollback() + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. + */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val quirks = DriverQuirks.get(url) + df.schema.fields foreach { field => { + val name = field.name + var typ: String = quirks.getJDBCType(field.dataType)._1 + if (typ == null) typ = field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case DecimalType.Unlimited => "DECIMAL(40,20)" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + } + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. + */ + def saveTable(df: DataFrame, url: String, table: String) { + val quirks = DriverQuirks.get(url) + var nullTypes: Array[Int] = df.schema.fields.map(field => { + var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 + if (nullType.isEmpty) { + field.dataType match { + case IntegerType => java.sql.Types.INTEGER + case LongType => java.sql.Types.BIGINT + case DoubleType => java.sql.Types.DOUBLE + case FloatType => java.sql.Types.REAL + case ShortType => java.sql.Types.INTEGER + case ByteType => java.sql.Types.INTEGER + case BooleanType => java.sql.Types.BIT + case StringType => java.sql.Types.CLOB + case BinaryType => java.sql.Types.BLOB + case TimestampType => java.sql.Types.TIMESTAMP + case DateType => java.sql.Types.DATE + case DecimalType.Unlimited => java.sql.Types.DECIMAL + case _ => throw new IllegalArgumentException( + s"Can't translate null value for field $field") + } + } else nullType.get + }).toArray + + val rddSchema = df.schema + df.foreachPartition { iterator => + JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes) + } + } + + } +} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 1af96c28d5fd..e3352d02787f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.SQLContext +import java.io.IOException + +import org.apache.hadoop.fs.Path + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider { +private[sql] class DefaultSource + extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) + } /** Returns a new base relation with the parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio, None)(sqlContext) + JSONRelation(path, samplingRatio, None)(sqlContext) } /** Returns a new base relation with the given schema and parameters. */ @@ -39,21 +50,70 @@ private[sql] class DefaultSource extends RelationProvider with SchemaRelationPro sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { - val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext) + JSONRelation(path, samplingRatio, Some(schema))(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = checkPath(parameters) + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val doSave = if (fs.exists(filesystemPath)) { + mode match { + case SaveMode.Append => + sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") + case SaveMode.Overwrite => { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table.") + } + true + } + case SaveMode.ErrorIfExists => + sys.error(s"path $path already exists.") + case SaveMode.Ignore => false + } + } else { + true + } + if (doSave) { + // Only save data when the save mode is not ignore. + data.toJSON.saveAsTextFile(path) + } + + createRelation(sqlContext, parameters, data.schema) } } private[sql] case class JSONRelation( - fileName: String, + path: String, samplingRatio: Double, userSpecifiedSchema: Option[StructType])( @transient val sqlContext: SQLContext) - extends TableScan { + extends BaseRelation + with TableScan + with InsertableRelation { + + // TODO: Support partitioned JSON relation. + private def baseRDD = sqlContext.sparkContext.textFile(path) - private def baseRDD = sqlContext.sparkContext.textFile(fileName) + override val needConversion: Boolean = false override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( @@ -62,6 +122,45 @@ private[sql] case class JSONRelation( samplingRatio, sqlContext.conf.columnNameOfCorruptRecord))) - override def buildScan() = + override def buildScan(): RDD[Row] = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + + if (overwrite) { + if (fs.exists(filesystemPath)) { + var success: Boolean = false + try { + success = fs.delete(filesystemPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${filesystemPath.toString} prior" + + s" to writing to JSON table.") + } + } + // Write the data. + data.toJSON.saveAsTextFile(path) + // Right now, we assume that the schema is not changed. We will not update the schema. + // schema = data.schema + } else { + // TODO: Support INSERT INTO + sys.error("JSON table only support INSERT OVERWRITE for now.") + } + } + + override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode() + + override def equals(other: Any): Boolean = other match { + case that: JSONRelation => + (this.path == that.path) && this.schema.sameType(that.schema) + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 9171939f7e8f..6e94e7056eb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.json -import java.io.StringWriter -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.core.JsonFactory +import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException} import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -50,7 +48,11 @@ private[sql] object JsonRDD extends Logging { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) val allKeys = - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) + if (schemaData.isEmpty()) { + Set.empty[(String, DataType)] + } else { + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) + } createSchema(allKeys) } @@ -179,7 +181,12 @@ private[sql] object JsonRDD extends Logging { } private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - ScalaReflection.typeOfObject orElse { + // For Integer values, use LongType by default. + val useLongType: PartialFunction[Any, DataType] = { + case value: IntegerType.InternalType => LongType + } + + useLongType orElse ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType.Unlimited @@ -196,13 +203,12 @@ private[sql] object JsonRDD extends Logging { * type conflicts. */ private def typeOfArray(l: Seq[Any]): ArrayType = { - val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. // If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull) + ArrayType(NullType, containsNull = true) } else { val elementType = elements.map { e => e match { @@ -214,7 +220,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType, containsNull) + ArrayType(elementType, containsNull = true) } } @@ -242,7 +248,7 @@ private[sql] object JsonRDD extends Logging { // The value associated with the key is an array. // Handle inner structs of an array. def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, containsNull) => { + case ArrayType(e: StructType, _) => { // The elements of this arrays are structs. v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { element => allKeysWithValueTypes(element) @@ -250,7 +256,7 @@ private[sql] object JsonRDD extends Logging { case (k, t) => (s"$key.$k", t) } } - case ArrayType(t1, containsNull) => + case ArrayType(t1, _) => v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { element => buildKeyPathForInnerStructs(element, t1) } @@ -303,6 +309,10 @@ private[sql] object JsonRDD extends Logging { val parsed = mapper.readValue(record, classOf[Object]) match { case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of the file " + + "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") } parsed @@ -377,10 +387,12 @@ private[sql] object JsonRDD extends Logging { } } - private def toDate(value: Any): Date = { + private def toDate(value: Any): Int = { value match { // only support string as date - case value: java.lang.String => new Date(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => + DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) + case value: java.sql.Date => DateUtils.fromJavaDate(value) } } @@ -388,7 +400,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) } } @@ -397,16 +409,19 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case IntegerType => value.asInstanceOf[IntegerType.InternalType] case LongType => toLong(value) case DoubleType => toDouble(value) case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case BooleanType => value.asInstanceOf[BooleanType.InternalType] case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case MapType(StringType, valueType, _) => + val map = value.asInstanceOf[Map[String, Any]] + map.mapValues(enforceCorrectType(_, valueType)).map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) case TimestampType => toTimestamp(value) @@ -428,14 +443,11 @@ private[sql] object JsonRDD extends Logging { /** Transforms a single Row to JSON using Jackson * - * @param jsonFactory a JsonFactory object to construct a JsonGenerator * @param rowSchema the schema object used for conversion + * @param gen a JsonGenerator object * @param row The row to convert */ - private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = { - val writer = new StringWriter() - val gen = jsonFactory.createGenerator(writer) - + private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) @@ -477,8 +489,5 @@ private[sql] object JsonRDD extends Logging { } valWriter(rowSchema, row) - gen.close() - writer.toString } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6dd39be80703..3f97a11ceb97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -34,8 +34,17 @@ import org.apache.spark.sql.execution.SparkPlan package object sql { /** - * Converts a logical plan into zero or more SparkPlans. + * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting + * with the query planner and is not designed to be stable across spark releases. Developers + * writing libraries should instead consider using the stable APIs provided in + * [[org.apache.spark.sql.sources]] */ @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + + /** + * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + */ + @deprecated("1.3.0", "use DataFrame") + type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 9d9150246c8d..36cb5e03bbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.parquet +import java.sql.Timestamp +import java.util.{TimeZone, Calendar} + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import jodd.datetime.JDateTime +import parquet.column.Dictionary import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ +import org.apache.spark.sql.parquet.timestamp.NanoTime /** * Collection of converters of Parquet types (group and primitive types) that @@ -60,6 +66,11 @@ private[sql] object CatalystConverter { // Using a different value will result in Parquet silently dropping columns. val ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME = "bag" val ARRAY_ELEMENTS_SCHEMA_NAME = "array" + // SPARK-4520: Thrift generated parquet files have different array element + // schema names than avro. Thrift parquet uses array_schema_name + "_tuple" + // as opposed to "array" used by default. For more information, check + // TestThriftSchemaConverter.java in parquet.thrift. + val THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX = "_tuple" val MAP_KEY_SCHEMA_NAME = "key" val MAP_VALUE_SCHEMA_NAME = "value" val MAP_SCHEMA_NAME = "map" @@ -79,7 +90,7 @@ private[sql] object CatalystConverter { createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) } // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType, false) => { + case ArrayType(elementType: AtomicType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields @@ -102,22 +113,24 @@ private[sql] object CatalystConverter { } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately - case StringType => { + case StringType => + new CatalystPrimitiveStringConverter(parent, fieldIndex) + case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value) + override def addInt(value: Int): Unit = + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) } } - case ShortType => { + case ByteType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) } } - case ByteType => { + case DateType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) } } case d: DecimalType => { @@ -126,8 +139,15 @@ private[sql] object CatalystConverter { parent.updateDecimal(fieldIndex, value, d) } } + case TimestampType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateTimestamp(fieldIndex, value) + } + } // All other primitive types use the default converter - case ctype: PrimitiveType => { // note: need the type tag here! + case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => { + // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) } case _ => throw new RuntimeException( @@ -179,6 +199,9 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = updateField(fieldIndex, value) + protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = updateField(fieldIndex, value) @@ -197,12 +220,14 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) - protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + updateField(fieldIndex, readTimestamp(value)) + + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) - } protected[parquet] def isRootConverter: Boolean = parent == null @@ -235,6 +260,13 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) dest.set(unscaled, precision, scale) } + + /** + * Read a Timestamp value from a Parquet Int96Value + */ + protected[parquet] def readTimestamp(value: Binary): Timestamp = { + CatalystTimestampConverter.convertToTimestamp(value) + } } /** @@ -293,9 +325,9 @@ private[parquet] class CatalystGroupConverter( override def start(): Unit = { current = ArrayBuffer.fill(size)(null) - converters.foreach { - converter => if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converters.foreach { converter => + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer() } } } @@ -366,6 +398,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = current.setInt(fieldIndex, value) + override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + current.update(fieldIndex, value) + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -384,8 +419,11 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - current.setString(fieldIndex, value.toStringUsingUTF8) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) + + override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = + current.update(fieldIndex, readTimestamp(value)) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -426,10 +464,103 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. + * Supports dictionaries to reduce Binary to String conversion overhead. + * + * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) + extends CatalystPrimitiveConverter(parent, fieldIndex) { + + private[this] var dict: Array[Array[Byte]] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary):Unit = + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } + + override def addValueFromDictionary(dictionaryId: Int): Unit = + parent.updateString(fieldIndex, dict(dictionaryId)) + + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value.getBytes) +} + private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } +private[parquet] object CatalystTimestampConverter { + // TODO most part of this comes from Hive-0.14 + // Hive code might have some issues, so we need to keep an eye on it. + // Also we use NanoTime and Int96Values from parquet-examples. + // We utilize jodd to convert between NanoTime and Timestamp + val parquetTsCalendar = new ThreadLocal[Calendar] + def getCalendar: Calendar = { + // this is a cache for the calendar instance. + if (parquetTsCalendar.get == null) { + parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) + } + parquetTsCalendar.get + } + val NANOS_PER_SECOND: Long = 1000000000 + val SECONDS_PER_MINUTE: Long = 60 + val MINUTES_PER_HOUR: Long = 60 + val NANOS_PER_MILLI: Long = 1000000 + + def convertToTimestamp(value: Binary): Timestamp = { + val nt = NanoTime.fromBinary(value) + val timeOfDayNanos = nt.getTimeOfDayNanos + val julianDay = nt.getJulianDay + val jDateTime = new JDateTime(julianDay.toDouble) + val calendar = getCalendar + calendar.set(Calendar.YEAR, jDateTime.getYear) + calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) + calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) + + // written in command style + var remainder = timeOfDayNanos + calendar.set( + Calendar.HOUR_OF_DAY, + (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) + calendar.set( + Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) + remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) + calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) + val nanos = remainder % NANOS_PER_SECOND + val ts = new Timestamp(calendar.getTimeInMillis) + ts.setNanos(nanos.toInt) + ts + } + + def convertFromTimestamp(ts: Timestamp): Binary = { + val calendar = getCalendar + calendar.setTime(ts) + val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) + // Hive-0.14 didn't set hour before get day number, while the day number should + // has something to do with hour, since julian day number grows at 12h GMT + // here we just follow what hive does. + val julianDay = jDateTime.getJulianDayNumber + + val hour = calendar.get(Calendar.HOUR_OF_DAY) + val minute = calendar.get(Calendar.MINUTE) + val second = calendar.get(Calendar.SECOND) + val nanos = ts.getNanos + // Hive-0.14 would use hours directly, that might be wrong, since the day starts + // from 12h in Julian. here we just follow what hive does. + val nanosOfDay = nanos + second * NANOS_PER_SECOND + + minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + + hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR + NanoTime(julianDay, nanosOfDay).toBinary + } +} + /** * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see @@ -482,7 +613,7 @@ private[parquet] class CatalystArrayConverter( override def start(): Unit = { if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converter.asInstanceOf[CatalystConverter].clearBuffer() } } @@ -506,13 +637,13 @@ private[parquet] class CatalystArrayConverter( * @param capacity The (initial) capacity of the buffer */ private[parquet] class CatalystNativeArrayConverter( - val elementType: NativeType, + val elementType: AtomicType, val index: Int, protected[parquet] val parent: CatalystConverter, protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) extends CatalystConverter { - type NativeType = elementType.JvmType + type NativeType = elementType.InternalType private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) @@ -583,9 +714,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f08350878f23..5eb1c6abc243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -164,33 +164,57 @@ private[sql] object ParquetFilters { case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeEq.lift(dataType).map(_(name, value)) + case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeEq.lift(dataType).map(_(name, value)) - + case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => + makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => + makeNotEq.lift(dataType).map(_(name, value)) case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLt.lift(dataType).map(_(name, value)) + case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLt.lift(dataType).map(_(name, value)) case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGt.lift(dataType).map(_(name, value)) + case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGt.lift(dataType).map(_(name, value)) case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) + case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLt.lift(dataType).map(_(name, value)) + case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLt.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLtEq.lift(dataType).map(_(name, value)) case And(lhs, rhs) => (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cde5160149e9..fcb9513ab66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,15 +18,17 @@ package org.apache.spark.sql.parquet import java.io.IOException +import java.util.logging.Level import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import parquet.hadoop.ParquetOutputFormat +import org.apache.spark.sql.types.{StructType, DataType} +import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -34,8 +36,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati /** * Relation that consists of data stored in a Parquet columnar format. * - * Users should interact with parquet files though a SchemaRDD, created by a [[SQLContext]] instead - * of using this class directly. + * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] + * instead of using this class directly. * * {{{ * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") @@ -65,20 +67,26 @@ private[sql] case class ParquetRelation( ParquetTypesConverter.readSchemaFromFile( new Path(path.split(",").head), conf, - sqlContext.conf.isParquetBinaryAsString) - + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp) lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override def newInstance(): this.type = { + ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + } // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case p: ParquetRelation => p.path == path && p.output == output case _ => false } + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(path, output) + } + // TODO: Use data from the footers. override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) } @@ -91,7 +99,7 @@ private[sql] object ParquetRelation { // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: - // 1) there's no handlers installed so there's none to + // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. // 2) The parquet.Log static initializer calls setUseParentHanders(false) @@ -99,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName()) + Class.forName(classOf[parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -108,6 +116,11 @@ private[sql] object ParquetRelation { // TODO(witgo): Need to set the log level ? // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + + // Disables WARN log message in ParquetOutputCommitter. + // See https://issues.apache.org/jira/browse/SPARK-5968 for details + Class.forName(classOf[ParquetOutputCommitter].getName) + java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. @@ -166,9 +179,13 @@ private[sql] object ParquetRelation { sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) .name()) ParquetRelation.enableLogForwarding() - ParquetTypesConverter.writeMetaData(attributes, path, conf) + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val schema = StructType.fromAttributes(attributes).asNullable + val newAttributes = schema.toAttributes + ParquetTypesConverter.writeMetaData(newAttributes, path, conf) new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = attributes + override val output = newAttributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cd17fde46a..a938b7757868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet import java.io.IOException import java.lang.{Long => JLong} -import java.text.SimpleDateFormat -import java.text.NumberFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.concurrent.{Callable, TimeUnit} -import java.util.{ArrayList, Collections, Date, List => JList} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -43,11 +42,13 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** @@ -55,7 +56,7 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext} * Parquet table scan operator. Imports the file that backs the given * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ -case class ParquetTableScan( +private[sql] case class ParquetTableScan( attributes: Seq[Attribute], relation: ParquetRelation, columnPruningPred: Seq[Expression]) @@ -125,6 +126,13 @@ case class ParquetTableScan( conf) if (requestedPartitionOrdinals.nonEmpty) { + // This check is based on CatalystConverter.createRootConverter. + val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) + + // Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into + // the `mapPartitionsWithInputSplit` closure below. + val outputSize = output.size + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = @@ -142,19 +150,47 @@ case class ParquetTableScan( relation.partitioningAttributes .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - new Iterator[Row] { - def hasNext = iter.hasNext - def next() = { - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. - var i = 0 - while (i < requestedPartitionOrdinals.size) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 + if (primitiveRow) { + new Iterator[Row] { + def hasNext: Boolean = iter.hasNext + def next(): Row = { + // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. + val row = iter.next()._2.asInstanceOf[SpecificMutableRow] + + // Parquet will leave partitioning columns empty, so we fill them in here. + var i = 0 + while (i < requestedPartitionOrdinals.size) { + row(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + row + } + } + } else { + // Create a mutable row since we need to fill in values from partition columns. + val mutableRow = new GenericMutableRow(outputSize) + new Iterator[Row] { + def hasNext: Boolean = iter.hasNext + def next(): Row = { + // We are using CatalystGroupConverter and it returns a GenericRow. + // Since GenericRow is not mutable, we just cast it to a Row. + val row = iter.next()._2.asInstanceOf[Row] + + var i = 0 + while (i < row.size) { + mutableRow(i) = row(i) + i += 1 + } + // Parquet will leave partitioning columns empty, so we fill them in here. + i = 0 + while (i < requestedPartitionOrdinals.size) { + mutableRow(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + mutableRow } - row } } } @@ -210,7 +246,7 @@ case class ParquetTableScan( * (only detected via filename pattern so will not catch all cases). */ @DeveloperApi -case class InsertIntoParquetTable( +private[sql] case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, overwrite: Boolean = false) @@ -219,7 +255,7 @@ case class InsertIntoParquetTable( /** * Inserts all rows into the Parquet file. */ - override def execute() = { + override def execute(): RDD[Row] = { // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -232,7 +268,7 @@ case class InsertIntoParquetTable( val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val writeSupport = - if (child.output.map(_.dataType).forall(_.isPrimitive)) { + if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { @@ -242,7 +278,10 @@ case class InsertIntoParquetTable( ParquetOutputFormat.setWriteSupportClass(job, writeSupport) val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(relation.output, conf) + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val schema = StructType.fromAttributes(relation.output).asNullable + RowWriteSupport.setSchema(schema.toAttributes, conf) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) @@ -263,7 +302,7 @@ case class InsertIntoParquetTable( childRdd } - override def output = child.output + override def output: Seq[Attribute] = child.output /** * Stores the given Row RDD as a Hadoop file. @@ -317,7 +356,7 @@ case class InsertIntoParquetTable( } finally { writer.close(hadoopContext) } - committer.commitTask(hadoopContext) + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) @@ -373,8 +412,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { - private var footers: JList[Footer] = _ - private var fileStatuses = Map.empty[Path, FileStatus] override def createRecordReader( @@ -395,46 +432,15 @@ private[parquet] class FilteringParquetRowInputFormat } } - override def getFooters(jobContext: JobContext): JList[Footer] = { - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache - - if (footers eq null) { - val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - val statuses = listStatus(jobContext) - fileStatuses = statuses.map(file => file.getPath -> file).toMap - if (statuses.isEmpty) { - footers = Collections.emptyList[Footer] - } else if (!cacheMetadata) { - // Read the footers from HDFS - footers = getFooters(conf, statuses) - } else { - // Read only the footers that are not in the footerCache - val foundFooters = footerCache.getAllPresent(statuses) - val toFetch = new ArrayList[FileStatus] - for (s <- statuses) { - if (!foundFooters.containsKey(s)) { - toFetch.add(s) - } - } - val newFooters = new mutable.HashMap[FileStatus, Footer] - if (toFetch.size > 0) { - val startFetch = System.currentTimeMillis - val fetched = getFooters(conf, toFetch) - logInfo(s"Fetched $toFetch footers in ${System.currentTimeMillis - startFetch} ms") - for ((status, i) <- toFetch.zipWithIndex) { - newFooters(status) = fetched.get(i) - } - footerCache.putAll(newFooters) - } - footers = new ArrayList[Footer](statuses.size) - for (status <- statuses) { - footers.add(newFooters.getOrElse(status, foundFooters.get(status))) - } - } - } + // This is only a temporary solution sicne we need to use fileStatuses in + // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these + // two methods. + override def getSplits(jobContext: JobContext): JList[InputSplit] = { + // First set fileStatuses. + val statuses = listStatus(jobContext) + fileStatuses = statuses.map(file => file.getPath -> file).toMap - footers + super.getSplits(jobContext) } // TODO Remove this method and related code once PARQUET-16 is fixed @@ -459,13 +465,21 @@ private[parquet] class FilteringParquetRowInputFormat val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) - val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] if (globalMetaData == null) { val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] return splits } + val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + val mergedMetadata = globalMetaData + .getKeyValueMetaData + .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata))) + + globalMetaData = new GlobalMetaData(globalMetaData.getSchema, + mergedMetadata, globalMetaData.getCreatedBy) + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, @@ -498,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat import parquet.filter2.compat.FilterCompat.Filter import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) @@ -647,6 +662,6 @@ private[parquet] object FileSystemHelper { sys.error("ERROR: attempting to append to set of Parquet files and found file" + s"that does not match name pattern: $other") case _ => 0 - }.reduceLeft((a, b) => if (a < b) b else a) + }.reduceOption(_ max _).getOrElse(0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index fd63ad814406..c45c431438ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -83,7 +83,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) + schema = ParquetTypesConverter.convertToAttributes( + parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -98,7 +99,11 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) if (requestedAttributes != null) { - parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) + // If the parquet file is thrift derived, there is a good chance that + // it will have the thrift class in metadata. + val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") + parquetSchema = ParquetTypesConverter + .convertFromAttributes(requestedAttributes, isThriftDerived) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) @@ -154,7 +159,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -184,28 +189,27 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value) + case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) } } } - private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { + private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -307,6 +311,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } + private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { + val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) + writer.addBinary(binaryNanoTime) + } } // Optimized for non-nested rows @@ -315,7 +323,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -338,10 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) @@ -351,6 +356,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case DateType => writer.addInteger(record.getInt(index)) + case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 02ce1b3e6d81..9d17516e0ef7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{SQLContext, SchemaRDD} -import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -34,10 +33,11 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -trait ParquetTest { +private[sql] trait ParquetTest { val sqlContext: SQLContext - import sqlContext._ + import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} + import sqlContext.{conf, sparkContext} protected def configuration = sparkContext.hadoopConfiguration @@ -49,11 +49,11 @@ trait ParquetTest { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(getConf(key)).toOption) - (keys, values).zipped.foreach(setConf) + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => setConf(key, value) + case (key, Some(value)) => conf.setConf(key, value) case (key, None) => conf.unsetConf(key) } } @@ -66,8 +66,9 @@ trait ParquetTest { * @todo Probably this method should be moved to a more general place */ protected def withTempPath(f: File => Unit): Unit = { - val file = util.getTempFilePath("parquetTest").getCanonicalFile - try f(file) finally if (file.exists()) Utils.deleteRecursively(file) + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) } /** @@ -89,39 +90,66 @@ trait ParquetTest { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) f(file.getCanonicalPath) } } /** - * Writes `data` to a Parquet file and reads it back as a SchemaRDD, which is then passed to `f`. - * The Parquet file will be deleted after `f` returns. + * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ - protected def withParquetRDD[T <: Product: ClassTag: TypeTag] + protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) - (f: SchemaRDD => Unit): Unit = { - withParquetFile(data)(path => f(parquetFile(path))) + (f: DataFrame => Unit): Unit = { + withParquetFile(data)(path => f(sqlContext.parquetFile(path))) } /** * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally dropTempTable(tableName) + try f finally sqlContext.dropTempTable(tableName) } /** - * Writes `data` to a Parquet file, reads it back as a SchemaRDD and registers it as a temporary - * table named `tableName`, then call `f`. The temporary table together with the Parquet file will - * be dropped/deleted after `f` returns. + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns. */ protected def withParquetTable[T <: Product: ClassTag: TypeTag] (data: Seq[T], tableName: String) (f: => Unit): Unit = { - withParquetRDD(data) { rdd => - rdd.registerTempTable(tableName) + withParquetDataFrame(data) { df => + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala deleted file mode 100644 index d5993656e022..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ /dev/null @@ -1,462 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.mapreduce.Job -import org.apache.spark.sql.test.TestSQLContext - -import parquet.example.data.{GroupWriter, Group} -import parquet.example.data.simple.SimpleGroup -import parquet.hadoop.{ParquetReader, ParquetFileReader, ParquetWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.example.GroupReadSupport -import parquet.hadoop.util.ContextUtil -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} - -import org.apache.spark.util.Utils - -// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport -// with an empty configuration (it is after all not intended to be used in this way?) -// and members are private so we need to make our own in order to pass the schema -// to the writer. -private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { - var groupWriter: GroupWriter = null - override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { - groupWriter = new GroupWriter(recordConsumer, schema) - } - override def init(configuration: Configuration): WriteContext = { - new WriteContext(schema, new java.util.HashMap[String, String]()) - } - override def write(record: Group) { - groupWriter.write(record) - } -} - -private[sql] object ParquetTestData { - - val testSchema = - """message myrecord { - optional boolean myboolean; - optional int32 myint; - optional binary mystring (UTF8); - optional int64 mylong; - optional float myfloat; - optional double mydouble; - }""" - - // field names for test assertion error messages - val testSchemaFieldNames = Seq( - "myboolean:Boolean", - "myint:Int", - "mystring:String", - "mylong:Long", - "myfloat:Float", - "mydouble:Double" - ) - - val subTestSchema = - """ - message myrecord { - optional boolean myboolean; - optional int64 mylong; - } - """ - - val testFilterSchema = - """ - message myrecord { - required boolean myboolean; - required int32 myint; - required binary mystring (UTF8); - required int64 mylong; - required float myfloat; - required double mydouble; - optional boolean myoptboolean; - optional int32 myoptint; - optional binary myoptstring (UTF8); - optional int64 myoptlong; - optional float myoptfloat; - optional double myoptdouble; - } - """ - - // field names for test assertion error messages - val subTestSchemaFieldNames = Seq( - "myboolean:Boolean", - "mylong:Long" - ) - - val testDir = Utils.createTempDir() - val testFilterDir = Utils.createTempDir() - - lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) - - val testNestedSchema1 = - // based on blogpost example, source: - // https://blog.twitter.com/2013/dremel-made-simple-with-parquet - // note: instead of string we have to use binary (?) otherwise - // Parquet gives us: - // IllegalArgumentException: expected one of [INT64, INT32, BOOLEAN, - // BINARY, FLOAT, DOUBLE, INT96, FIXED_LEN_BYTE_ARRAY] - // Also repeated primitives seem tricky to convert (AvroParquet - // only uses them in arrays?) so only use at most one in each group - // and nothing else in that group (-> is mapped to array)! - // The "values" inside ownerPhoneNumbers is a keyword currently - // so that array types can be translated correctly. - """ - message AddressBook { - required binary owner (UTF8); - optional group ownerPhoneNumbers { - repeated binary array (UTF8); - } - optional group contacts { - repeated group array { - required binary name (UTF8); - optional binary phoneNumber (UTF8); - } - } - } - """ - - - val testNestedSchema2 = - """ - message TestNested2 { - required int32 firstInt; - optional int32 secondInt; - optional group longs { - repeated int64 array; - } - required group entries { - repeated group array { - required double value; - optional boolean truth; - } - } - optional group outerouter { - repeated group array { - repeated group array { - repeated int32 array; - } - } - } - } - """ - - val testNestedSchema3 = - """ - message TestNested3 { - required int32 x; - optional group booleanNumberPairs { - repeated group array { - required int32 key; - optional group value { - repeated group array { - required double nestedValue; - optional boolean truth; - } - } - } - } - } - """ - - val testNestedSchema4 = - """ - message TestNested4 { - required int32 x; - optional group data1 { - repeated group map { - required binary key (UTF8); - required int32 value; - } - } - required group data2 { - repeated group map { - required binary key (UTF8); - required group value { - required int64 payload1; - optional binary payload2 (UTF8); - } - } - } - } - """ - - val testNestedDir1 = Utils.createTempDir() - val testNestedDir2 = Utils.createTempDir() - val testNestedDir3 = Utils.createTempDir() - val testNestedDir4 = Utils.createTempDir() - - lazy val testNestedData1 = - new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) - lazy val testNestedData2 = - new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) - - def writeFile() = { - testDir.delete() - val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 until 15) { - val record = new SimpleGroup(schema) - if (i % 3 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - if (i % 5 == 0) { - record.add(1, 5) - } - record.add(2, "abc") - record.add(3, i.toLong << 33) - record.add(4, 2.5F) - record.add(5, 4.5D) - writer.write(record) - } - writer.close() - } - - def writeFilterFile(records: Int = 200) = { - // for microbenchmark use: records = 300000000 - testFilterDir.delete - val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 to records) { - val record = new SimpleGroup(schema) - if (i % 4 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - record.add(1, i) - record.add(2, i.toString) - record.add(3, i.toLong) - record.add(4, i.toFloat + 0.5f) - record.add(5, i.toDouble + 0.5d) - if (i % 2 == 0) { - if (i % 3 == 0) { - record.add(6, true) - } else { - record.add(6, false) - } - record.add(7, i) - record.add(8, i.toString) - record.add(9, i.toLong) - record.add(10, i.toFloat + 0.5f) - record.add(11, i.toDouble + 0.5d) - } - - writer.write(record) - } - writer.close() - } - - def writeNestedFile1() { - // example data from https://blog.twitter.com/2013/dremel-made-simple-with-parquet - testNestedDir1.delete() - val path: Path = new Path(new Path(testNestedDir1.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema1) - - val r1 = new SimpleGroup(schema) - r1.add(0, "Julien Le Dem") - r1.addGroup(1) - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 123 4567") - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "555 666 1337") - .append(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, "XXX XXX XXXX") - val contacts = r1.addGroup(2) - contacts.addGroup(0) - .append("name", "Dmitriy Ryaboy") - .append("phoneNumber", "555 987 6543") - contacts.addGroup(0) - .append("name", "Chris Aniszczyk") - - val r2 = new SimpleGroup(schema) - r2.add(0, "A. Nonymous") - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.write(r2) - writer.close() - } - - def writeNestedFile2() { - testNestedDir2.delete() - val path: Path = new Path(new Path(testNestedDir2.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema2) - - val r1 = new SimpleGroup(schema) - r1.add(0, 1) - r1.add(1, 7) - val longs = r1.addGroup(2) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME , 1.toLong << 32) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 33) - longs.add(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 1.toLong << 34) - val booleanNumberPair = r1.addGroup(3).addGroup(0) - booleanNumberPair.add("value", 2.5) - booleanNumberPair.add("truth", false) - val top_level = r1.addGroup(4) - val second_level_a = top_level.addGroup(0) - val second_level_b = top_level.addGroup(0) - val third_level_aa = second_level_a.addGroup(0) - val third_level_ab = second_level_a.addGroup(0) - val third_level_c = second_level_b.addGroup(0) - third_level_aa.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 7) - third_level_ab.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 8) - third_level_c.add( - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, - 9) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - def writeNestedFile3() { - testNestedDir3.delete() - val path: Path = new Path(new Path(testNestedDir3.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3) - - val r1 = new SimpleGroup(schema) - r1.add(0, 1) - val booleanNumberPairs = r1.addGroup(1) - val g1 = booleanNumberPairs.addGroup(0) - g1.add(0, 1) - val nested1 = g1.addGroup(1) - val ng1 = nested1.addGroup(0) - ng1.add(0, 1.5) - ng1.add(1, false) - val ng2 = nested1.addGroup(0) - ng2.add(0, 2.5) - ng2.add(1, true) - val g2 = booleanNumberPairs.addGroup(0) - g2.add(0, 2) - val ng3 = g2.addGroup(1) - .addGroup(0) - ng3.add(0, 3.5) - ng3.add(1, false) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - def writeNestedFile4() { - testNestedDir4.delete() - val path: Path = new Path(new Path(testNestedDir4.toURI), new Path("part-r-0.parquet")) - val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema4) - - val r1 = new SimpleGroup(schema) - r1.add(0, 7) - val map1 = r1.addGroup(1) - val keyValue1 = map1.addGroup(0) - keyValue1.add(0, "key1") - keyValue1.add(1, 1) - val keyValue2 = map1.addGroup(0) - keyValue2.add(0, "key2") - keyValue2.add(1, 2) - val map2 = r1.addGroup(2) - val keyValue3 = map2.addGroup(0) - // TODO: currently only string key type supported - keyValue3.add(0, "seven") - val valueGroup1 = keyValue3.addGroup(1) - valueGroup1.add(0, 42.toLong) - valueGroup1.add(1, "the answer") - val keyValue4 = map2.addGroup(0) - // TODO: currently only string key type supported - keyValue4.add(0, "eight") - val valueGroup2 = keyValue4.addGroup(1) - valueGroup2.add(0, 49.toLong) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - writer.write(r1) - writer.close() - } - - // TODO: this is not actually used anywhere but useful for debugging - /* def readNestedFile(file: File, schemaString: String): Unit = { - val configuration = new Configuration() - val path = new Path(new Path(file.toURI), new Path("part-r-0.parquet")) - val fs: FileSystem = path.getFileSystem(configuration) - val schema: MessageType = MessageTypeParser.parseMessageType(schemaString) - assert(schema != null) - val outputStatus: FileStatus = fs.getFileStatus(new Path(path.toString)) - val footers = ParquetFileReader.readFooter(configuration, outputStatus) - assert(footers != null) - val reader = new ParquetReader(new Path(path.toString), new GroupReadSupport()) - val first = reader.read() - assert(first != null) - } */ - - // to test golb pattern (wild card pattern matching for parquetFile input - val testGlobDir = Utils.createTempDir() - val testGlobSubDir1 = Utils.createTempDir(testGlobDir.getPath) - val testGlobSubDir2 = Utils.createTempDir(testGlobDir.getPath) - val testGlobSubDir3 = Utils.createTempDir(testGlobDir.getPath) - - def writeGlobFiles() = { - val subDirs = Array(testGlobSubDir1, testGlobSubDir2, testGlobSubDir3) - - subDirs.foreach { dir => - val path: Path = new Path(new Path(dir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - for(i <- 0 until 15) { - val record = new SimpleGroup(schema) - if(i % 3 == 0) { - record.add(0, true) - } else { - record.add(0, false) - } - if(i % 5 == 0) { - record.add(1, 5) - } - record.add(2, "abc") - record.add(3, i.toLong << 33) - record.add(4, 2.5F) - record.add(5, 4.5D) - writer.write(record) - } - writer.close() - } - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 6d8c682ccced..1dc819b5d7b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,24 +19,23 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job - import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} +import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition +import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkException} // Implicits import scala.collection.JavaConversions._ @@ -49,12 +48,15 @@ private[parquet] case class ParquetTypeInfo( length: Option[Int] = None) private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = - classOf[PrimitiveType] isAssignableFrom ctype.getClass + def isPrimitiveType(ctype: DataType): Boolean = ctype match { + case _: NumericType | BooleanType | StringType | BinaryType => true + case _: DataType => false + } def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binaryAsString: Boolean): DataType = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean): DataType = { val originalType = parquetType.getOriginalType val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { @@ -64,8 +66,11 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 + if originalType == ParquetOriginalType.DATE => DateType case ParquetPrimitiveTypeName.INT32 => IntegerType case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") @@ -103,7 +108,9 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { + def toDataType(parquetType: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -125,7 +132,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -137,9 +144,12 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } case ParquetOriginalType.MAP => { @@ -152,8 +162,10 @@ private[parquet] object ParquetTypesConverter extends Logging { "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } @@ -163,8 +175,10 @@ private[parquet] object ParquetTypesConverter extends Logging { val keyValueGroup = groupType.getFields.apply(0).asGroupType() assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) - val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) + val keyType = + toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) + val valueType = + toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) MapType(keyType, valueType, keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) } else if (correspondsToArray(groupType)) { // ArrayType @@ -172,16 +186,19 @@ private[parquet] object ParquetTypesConverter extends Logging { if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { val bag = field.asGroupType() assert(bag.getFieldCount == 1) - ArrayType(toDataType(bag.getFields.apply(0), isBinaryAsString), containsNull = true) + ArrayType( + toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), + containsNull = true) } else { - ArrayType(toDataType(field, isBinaryAsString), containsNull = false) + ArrayType( + toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) } } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype, isBinaryAsString), + toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -209,7 +226,10 @@ private[parquet] object ParquetTypesConverter extends Logging { // There is no type for Byte or Short so we promote them to INT32. case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case DateType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) case DecimalType.Fixed(precision, scale) if precision <= 18 => // TODO: for now, our writer only supports decimals that fit in a Long Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, @@ -270,13 +290,19 @@ private[parquet] object ParquetTypesConverter extends Logging { ctype: DataType, name: String, nullable: Boolean = true, - inArray: Boolean = false): ParquetType = { + inArray: Boolean = false, + toThriftSchemaNames: Boolean = false): ParquetType = { val repetition = if (inArray) { Repetition.REPEATED } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } + val arraySchemaName = if (toThriftSchemaNames) { + name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX + } else { + CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME + } val typeInfo = fromPrimitiveDataType(ctype) typeInfo.map { case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => @@ -291,22 +317,24 @@ private[parquet] object ParquetTypesConverter extends Logging { }.getOrElse { ctype match { case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray) + fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + arraySchemaName, nullable = false, - inArray = true) + inArray = true, + toThriftSchemaNames) ConversionPatterns.listType(repetition, name, parquetElementType) } case ArrayType(elementType, true) => { val parquetElementType = fromDataType( elementType, - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, + arraySchemaName, nullable = true, - inArray = false) + inArray = false, + toThriftSchemaNames) ConversionPatterns.listType( repetition, name, @@ -317,7 +345,8 @@ private[parquet] object ParquetTypesConverter extends Logging { } case StructType(structFields) => { val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) + field => fromDataType(field.dataType, field.name, field.nullable, + inArray = false, toThriftSchemaNames) } new ParquetGroupType(repetition, name, fields.toSeq) } @@ -327,13 +356,15 @@ private[parquet] object ParquetTypesConverter extends Logging { keyType, CatalystConverter.MAP_KEY_SCHEMA_NAME, nullable = false, - inArray = false) + inArray = false, + toThriftSchemaNames) val parquetValueType = fromDataType( valueType, CatalystConverter.MAP_VALUE_SCHEMA_NAME, nullable = valueContainsNull, - inArray = false) + inArray = false, + toThriftSchemaNames) ConversionPatterns.mapType( repetition, name, @@ -345,7 +376,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -353,14 +386,17 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field, isBinaryAsString), + toDataType(field, isBinaryAsString, isInt96AsTimestamp), field.getRepetition != Repetition.REQUIRED)()) } - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + def convertFromAttributes(attributes: Seq[Attribute], + toThriftSchemaNames: Boolean = false): MessageType = { + checkSpecialCharacters(attributes) val fields = attributes.map( attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable)) + fromDataType(attribute.dataType, attribute.name, attribute.nullable, + toThriftSchemaNames = toThriftSchemaNames)) new MessageType("root", fields) } @@ -371,7 +407,20 @@ private[parquet] object ParquetTypesConverter extends Logging { } } + private def checkSpecialCharacters(schema: Seq[Attribute]) = { + // ,;{}()\n\t= and space character are special characters in Parquet schema + schema.map(_.name).foreach { name => + if (name.matches(".*[ ,;{}()\n\t=].*")) { + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + } + } + def convertToString(schema: Seq[Attribute]): String = { + checkSpecialCharacters(schema) StructType.fromAttributes(schema).json } @@ -476,7 +525,8 @@ private[parquet] object ParquetTypesConverter extends Logging { def readSchemaFromFile( origPath: Path, conf: Option[Configuration], - isBinaryAsString: Boolean): Seq[Attribute] = { + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -485,7 +535,9 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) + readMetaData(origPath, conf).getFileMetaData.getSchema, + isBinaryAsString, + isInt96AsTimestamp) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1b50afbbabcb..85e60733bc57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -16,207 +16,471 @@ */ package org.apache.spark.sql.parquet -import java.util.{List => JList} +import java.io.IOException +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.net.URI +import java.text.SimpleDateFormat +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Try +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job} - -import parquet.hadoop.ParquetInputFormat +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} +import parquet.filter2.predicate.FilterApi +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop.metadata.CompressionCodecName import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{Partition => SparkPartition, Logging} -import org.apache.spark.rdd.{NewHadoopPartition, RDD} -import org.apache.spark.sql.{SQLConf, Row, SQLContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} +import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} - +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} /** - * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option - * required is `path`, which should be the location of a collection of, optionally partitioned, - * parquet files. + * Allows creation of Parquet based tables using the syntax: + * {{{ + * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...) + * }}} + * + * Supported options include: + * + * - `path`: Required. When reading Parquet files, `path` should point to the location of the + * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. + * In the latter case, this data source tries to discover partitioning information if the the + * directory is structured in the same style of Hive partitioned tables. When writing Parquet + * file, `path` should point to the destination folder. + * + * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but + * compatible) schemas stored in all Parquet part-files. + * + * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is + * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration + * in Hive. */ -class DefaultSource extends RelationProvider { +private[sql] class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + } + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val path = - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) + } - ParquetRelation2(path)(sqlContext) + /** Returns a new base relation with the given parameters and schema. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) + } + + /** Returns a new base relation with the given parameters and save given data into it. */ + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = checkPath(parameters) + val filesystemPath = new Path(path) + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val doInsertion = (mode, fs.exists(filesystemPath)) match { + case (SaveMode.ErrorIfExists, true) => + sys.error(s"path $path already exists.") + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + } + + val relation = if (doInsertion) { + // This is a hack. We always set nullable/containsNull/valueContainsNull to true + // for the schema of a parquet data. + val df = + sqlContext.createDataFrame( + data.queryExecution.toRdd, + data.schema.asNullable, + needsConversion = false) + val createdRelation = + createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] + createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) + createdRelation + } else { + // If the save mode is Ignore, we will just create the relation based on existing data. + createRelation(sqlContext, parameters) + } + + relation } } -private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) +private[sql] case class Partition(values: Row, path: String) + +private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * currently not intended as a full replacement of the parquet support in Spark SQL though it is - * likely that it will eventually subsume the existing physical plan implementation. - * - * Compared with the current implementation, this class has the following notable differences: + * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will + * be deprecated and eventually removed once this version is proved to be stable enough. * - * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` - * located at `path`. Currently only a single partitioning column is supported and it must - * be an integer. This class supports both fully self-describing data, which contains the partition - * key, and data where the partition key is only present in the folder structure. The presence - * of the partitioning key in the data is also auto-detected. The `null` partition is not yet - * supported. + * Compared with the old implementation, this class has the following notable differences: * - * Metadata: The metadata is automatically discovered by reading the first parquet file present. - * There is currently no support for working with files that have different schema. Additionally, - * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached - * to improve the speed of interactive querying. When data is added to a table it must be dropped - * and recreated to pick up any changes. - * - * Statistics: Statistics for the size of the table are automatically populated during metadata - * discovery. + * - Partitioning discovery: Hive style multi-level partitions are auto discovered. + * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source + * can detect and merge schemas from all Parquet part-files as long as they are compatible. + * Also, metadata and [[FileStatus]]es are cached for better performance. + * - Statistics: Statistics for the size of the table are automatically populated during schema + * discovery. */ @DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) - extends CatalystScan with Logging { +private[sql] case class ParquetRelation2( + paths: Seq[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) + extends BaseRelation + with CatalystScan + with InsertableRelation + with SparkHadoopMapReduceUtil + with Logging { + + // Should we merge schemas from all Parquet part-files? + private val shouldMergeSchemas = + parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + + // Optional Metastore schema, used when converting Hive Metastore Parquet table + private val maybeMetastoreSchema = + parameters + .get(ParquetRelation2.METASTORE_SCHEMA) + .map(s => DataType.fromJson(s).asInstanceOf[StructType]) + + // Hive uses this as part of the default partition name when the partition column value is null + // or empty string + private val defaultPartitionName = parameters.getOrElse( + ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") + + override def equals(other: Any): Boolean = other match { + case relation: ParquetRelation2 => + // If schema merging is required, we don't compare the actual schemas since they may evolve. + val schemaEquality = if (shouldMergeSchemas) { + shouldMergeSchemas == relation.shouldMergeSchemas + } else { + schema == relation.schema + } + + paths.toSet == relation.paths.toSet && + schemaEquality && + maybeMetastoreSchema == relation.maybeMetastoreSchema && + maybePartitionSpec == relation.maybePartitionSpec + + case _ => false + } + + override def hashCode(): Int = { + if (shouldMergeSchemas) { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } else { + com.google.common.base.Objects.hashCode( + shouldMergeSchemas: java.lang.Boolean, + schema, + paths.toSet, + maybeMetastoreSchema, + maybePartitionSpec) + } + } + + private[sql] def sparkContext = sqlContext.sparkContext + + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ - def sparkContext = sqlContext.sparkContext + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ - // Minor Hack: scala doesnt seem to respect @transient for vals declared via extraction - @transient - private var partitionKeys: Seq[String] = _ - @transient - private var partitions: Seq[Partition] = _ - discoverPartitions() + // Parquet footer cache. + var footers: Map[FileStatus, Footer] = _ - // TODO: Only finds the first partition, assumes the key is of type Integer... - private def discoverPartitions() = { - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val partValue = "([^=]+)=([^=]+)".r + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ - val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) - val childDirs = childrenOfPath.filter(s => s.isDir) + // Partition spec of this table, including names, data types, and values of each partition + // column, and paths of each partition. + var partitionSpec: PartitionSpec = _ - if (childDirs.size > 0) { - val partitionPairs = childDirs.map(_.getPath.getName).map { - case partValue(key, value) => (key, value) + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var parquetSchema: StructType = _ + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + // Indicates whether partition columns are also included in Parquet data file schema. If not, + // we need to fill in partition column values into read rows when scanning the table. + var partitionKeysIncludedInParquetSchema: Boolean = _ + + def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + ParquetRelation.enableLogForwarding() + ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) + } + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + val baseStatuses = paths.distinct.map { p => + val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + + if (!fs.exists(qualified) && maybeSchema.isDefined) { + fs.mkdirs(qualified) + prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) + } + + fs.getFileStatus(qualified) + }.toArray + assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = baseStatuses.flatMap { f => + val fs = FileSystem.get(f.getPath.toUri, sparkContext.hadoopConfiguration) + SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + } } - val foundKeys = partitionPairs.map(_._1).distinct - if (foundKeys.size > 1) { - sys.error(s"Too many distinct partition keys: $foundKeys") + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => + val parquetMetadata = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) + f -> new Footer(f.getPath, parquetMetadata) + }.seq.toMap + + partitionSpec = maybePartitionSpec.getOrElse { + val partitionDirs = leaves + .filterNot(baseStatuses.contains) + .map(_.getPath.getParent) + .distinct + + if (partitionDirs.nonEmpty) { + // Parses names and values of partition columns, and infer their data types. + ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) + } else { + // No partition directories found, makes an empty specification + PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + } } - // Do a parallel lookup of partition metadata. - val partitionFiles = - childDirs.par.map { d => - fs.listStatus(d.getPath) - // TODO: Is there a standard hadoop function for this? - .filterNot(_.getPath.getName.startsWith("_")) - .filterNot(_.getPath.getName.startsWith(".")) - }.seq - - partitionKeys = foundKeys.toSeq - partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => - Partition(Map(key -> value.toInt), files) - }.toSeq - } else { - partitionKeys = Nil - partitions = Partition(Map.empty, childrenOfPath) :: Nil + // To get the schema. We first try to get the schema defined in maybeSchema. + // If maybeSchema is not defined, we will try to get the schema from existing parquet data + // (through readSchema). If data does not exist, we will try to get the schema defined in + // maybeMetastoreSchema (defined in the options of the data source). + // Finally, if we still could not get the schema. We throw an error. + parquetSchema = + maybeSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) + + partitionKeysIncludedInParquetSchema = + isPartitioned && + partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) + + schema = { + val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { + parquetSchema + } else { + StructType(parquetSchema.fields ++ partitionColumns.fields) + } + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // insensitivity issue and possible schema mismatch. + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) + .getOrElse(fullRelationSchema) + } + } + + private def readSchema(): Option[StructType] = { + // Sees which file(s) we need to touch in order to figure out the schema. + val filesToTouch = + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } + + ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) } } - override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum + @transient private val metadataCache = new MetadataCache + metadataCache.refresh() - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString)) + def partitionSpec: PartitionSpec = metadataCache.partitionSpec - val dataIncludesKey = - partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) + def partitionColumns: StructType = metadataCache.partitionSpec.partitionColumns - override val schema = - if (dataIncludesKey) { - dataSchema - } else { - StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) - } + def partitions: Seq[Partition] = metadataCache.partitionSpec.partitions - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - // This is mostly a hack so that we can use the existing parquet filter code. - val requiredColumns = output.map(_.name) + def isPartitioned: Boolean = partitionColumns.nonEmpty - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) + private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema - val requestedSchema = StructType(requiredColumns.map(schema(_))) + private def parquetSchema = metadataCache.parquetSchema - val partitionKeySet = partitionKeys.toSet - val rawPredicate = - predicates - .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet)) - .reduceOption(And) - .getOrElse(Literal(true)) + override def schema: StructType = metadataCache.schema - // Translate the predicate so that it reads from the information derived from the - // folder structure - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = partitionKeys.indexWhere(a.name == _) - BoundReference(idx, IntegerType, nullable = true) - } + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } - val inputData = new GenericMutableRow(partitionKeys.size) - val pruningCondition = InterpretedPredicate(castedPredicate) + // Skip type conversion + override val needConversion: Boolean = false - val selectedPartitions = - if (partitionKeys.nonEmpty && predicates.nonEmpty) { - partitions.filter { part => - inputData(0) = part.partitionValues.values.head - pruningCondition(inputData) - } - } else { - partitions + // TODO Should calculate per scan size + // It's common that a query only scans a fraction of a large Parquet file. Returning size of the + // whole Parquet file disables some optimizations in this case (e.g. broadcast join). + override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum + + // This is mostly a hack so that we can use the existing parquet filter code. + override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + val job = new Job(sparkContext.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val jobConf: Configuration = ContextUtil.getConfiguration(job) + + val selectedPartitions = prunePartitions(predicates, partitions) + val selectedFiles = if (isPartitioned) { + selectedPartitions.flatMap { p => + metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) } + } else { + metadataCache.dataStatuses.toSeq + } + val selectedFooters = selectedFiles.map(metadataCache.footers) - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) + FileInputFormat.setInputPaths(job, selectedPaths: _*) + } + + // Try to push down filters when filter push-down is enabled. + if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) } - // Push down filters when possible - predicates - .reduceOption(And) - .flatMap(ParquetFilters.createFilter) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + if (isPartitioned) { + logInfo { + val percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + s"Reading $percentRead% of partitions" + } + } - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of $path partitions") + val requiredColumns = output.map(_.name) + val requestedSchema = StructType(requiredColumns.map(schema(_))) // Store both requested and original schema in `Configuration` jobConf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) + convertToString(requestedSchema.toAttributes)) jobConf.set( RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(schema.toAttributes)) + convertToString(schema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( + new NewHadoopRDD( sparkContext, classOf[FilteringParquetRowInputFormat], classOf[Void], @@ -225,66 +489,530 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) val cacheMetadata = useCache @transient - val cachedStatus = selectedPartitions.flatMap(_.files) + val cachedStatus = selectedFiles.map { st => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val newPath = new Path(st.getPath.toUri.toString) + + new FileStatus( + st.getLen, + st.isDir, + st.getReplication, + st.getBlockSize, + st.getModificationTime, + st.getAccessTime, + st.getPermission, + st.getOwner, + st.getGroup, + newPath) + } + + @transient + val cachedFooters = selectedFooters.map { f => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + } + // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { - val inputFormat = - if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - } - } else { - new FilteringParquetRowInputFormat - } + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - inputFormat match { - case configurable: Configurable => - configurable.setConf(getConf) - case _ => + override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + } + } else { + new FilteringParquetRowInputFormat } + val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } - result } } - // The ordinal for the partition key in the result row, if requested. - val partitionKeyLocation = - partitionKeys - .headOption - .map(requiredColumns.indexOf(_)) - .getOrElse(-1) + // The ordinals for partition keys in the result row, if requested. + val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { + case (name, index) => index -> requiredColumns.indexOf(name) + }.toMap.filter { + case (_, index) => index >= 0 + } // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataIncludesKey && partitionKeyLocation != -1) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - val currentValue = partValues.values.head.toInt - iter.map { pair => - val res = pair._2.asInstanceOf[SpecificMutableRow] - res.setInt(partitionKeyLocation, currentValue) - res + if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { + // This check is based on CatalystConverter.createRootConverter. + val primitiveRow = + requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) + + baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => + val partValues = selectedPartitions.collectFirst { + case p if split.getPath.getParent.toString == p.path => + CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] + }.get + + val requiredPartOrdinal = partitionKeyLocations.keys.toSeq + + if (primitiveRow) { + iterator.map { pair => + // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. + val row = pair._2.asInstanceOf[SpecificMutableRow] + var i = 0 + while (i < requiredPartOrdinal.size) { + // TODO Avoids boxing cost here! + val partOrdinal = requiredPartOrdinal(i) + row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) + i += 1 + } + row + } + } else { + // Create a mutable row since we need to fill in values from partition columns. + val mutableRow = new GenericMutableRow(requestedSchema.size) + iterator.map { pair => + // We are using CatalystGroupConverter and it returns a GenericRow. + // Since GenericRow is not mutable, we just cast it to a Row. + val row = pair._2.asInstanceOf[Row] + var i = 0 + while (i < row.size) { + // TODO Avoids boxing cost here! + mutableRow(i) = row(i) + i += 1 + } + + i = 0 + while (i < requiredPartOrdinal.size) { + // TODO Avoids boxing cost here! + val partOrdinal = requiredPartOrdinal(i) + mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) + i += 1 + } + mutableRow + } } } } else { baseRDD.map(_._2) } } + + private def prunePartitions( + predicates: Seq[Expression], + partitions: Seq[Partition]): Seq[Partition] = { + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + val rawPredicate = + partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) + val boundPredicate = InterpretedPredicate.create(rawPredicate transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + if (isPartitioned && partitionPruningPredicates.nonEmpty) { + partitions.filter(p => boundPredicate(p.values)) + } else { + partitions + } + } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") + + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // and incompatible schema the execution will fail. It would be nice + // to catch this early one, maybe having the planner validate the schema + // before calling execute(). + + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val writeSupport = + if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) + + val conf = ContextUtil.getConfiguration(job) + RowWriteSupport.setSchema(data.schema.toAttributes, conf) + + val destinationPath = new Path(paths.head) + + if (overwrite) { + val fs = destinationPath.getFileSystem(conf) + if (fs.exists(destinationPath)) { + var success: Boolean = false + try { + success = fs.delete(destinationPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet table:\n${e.toString}") + } + if (!success) { + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet table.") + } + } + } + + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[Row]) + FileOutputFormat.setOutputPath(job, destinationPath) + + val wrappedConf = new SerializableWritable(job.getConfiguration) + val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) + val stageId = sqlContext.sparkContext.newRddId() + + val taskIdOffset = if (overwrite) { + 1 + } else { + FileSystemHelper.findMaxTaskId( + FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 + } + + def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { + /* "reduce task" */ + val attemptId = newTaskAttemptID( + jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = new AppendingParquetOutputFormat(taskIdOffset) + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + try { + while (iterator.hasNext) { + val row = iterator.next() + writer.write(null, row) + } + } finally { + writer.close(hadoopContext) + } + + SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) + } + val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. + */ + val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + jobCommitter.setupJob(jobTaskContext) + sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) + jobCommitter.commitJob(jobTaskContext) + + metadataCache.refresh() + } +} + +private[sql] object ParquetRelation2 extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + val MERGE_SCHEMA = "mergeSchema" + + // Default partition name to use when the partition column value is null or empty string. + val DEFAULT_PARTITION_NAME = "partition.defaultName" + + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" + + private[parquet] def readSchema( + footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + footers.map { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val parquetSchema = metadata.getSchema + val maybeSparkSchema = metadata + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .flatMap { serializedSchema => + // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to + // whatever is available. + Try(DataType.fromJson(serializedSchema)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .toOption + } + + maybeSparkSchema.getOrElse { + // Falls back to Parquet schema if Spark SQL schema is absent. + StructType.fromAttributes( + // TODO Really no need to use `Attribute` here, we only need to know the data type. + convertToAttributes( + parquetSchema, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) + } + }.reduceOption { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage: String = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). + // However, we are already using Catalyst expressions for partition pruning and predicate + // push-down here... + private[parquet] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + require(columnNames.size == literals.size) + } + + /** + * Given a group of qualified paths, tries to parse them and returns a partition specification. + * For example, given: + * {{{ + * hdfs://:/path/to/partition/a=1/b=hello/c=3.14 + * hdfs://:/path/to/partition/a=2/b=world/c=6.28 + * }}} + * it returns: + * {{{ + * PartitionSpec( + * partitionColumns = StructType( + * StructField(name = "a", dataType = IntegerType, nullable = true), + * StructField(name = "b", dataType = StringType, nullable = true), + * StructField(name = "c", dataType = DoubleType, nullable = true)), + * partitions = Seq( + * Partition( + * values = Row(1, "hello", 3.14), + * path = "hdfs://:/path/to/partition/a=1/b=hello/c=3.14"), + * Partition( + * values = Row(2, "world", 6.28), + * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) + * }}} + */ + private[parquet] def parsePartitions( + paths: Seq[Path], + defaultPartitionName: String): PartitionSpec = { + val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) + val fields = { + val (PartitionValues(columnNames, literals)) = partitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + StructField(name, dataType, nullable = true) + } + } + + val partitions = partitionValues.zip(paths).map { + case (PartitionValues(_, literals), path) => + Partition(Row(literals.map(_.value): _*), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } + + /** + * Parses a single partition, returns column names and values of each partition column. For + * example, given: + * {{{ + * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 + * }}} + * it returns: + * {{{ + * PartitionValues( + * Seq("a", "b", "c"), + * Seq( + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) + * }}} + */ + private[parquet] def parsePartition( + path: Path, + defaultPartitionName: String): PartitionValues = { + val columns = ArrayBuffer.empty[(String, Literal)] + // Old Hadoop versions don't have `Path.isRoot` + var finished = path.getParent == null + var chopped = path + + while (!finished) { + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + maybeColumn.foreach(columns += _) + chopped = chopped.getParent + finished = maybeColumn.isEmpty || chopped.getParent == null + } + + val (columnNames, values) = columns.reverse.unzip + PartitionValues(columnNames, values) + } + + private def parsePartitionColumn( + columnSpec: String, + defaultPartitionName: String): Option[(String, Literal)] = { + val equalSignIndex = columnSpec.indexOf('=') + if (equalSignIndex == -1) { + None + } else { + val columnName = columnSpec.take(equalSignIndex) + assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") + + val rawColumnValue = columnSpec.drop(equalSignIndex + 1) + assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") + + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + Some(columnName -> literal) + } + } + + /** + * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- + * casting order is: + * {{{ + * NullType -> + * IntegerType -> LongType -> + * FloatType -> DoubleType -> DecimalType.Unlimited -> + * StringType + * }}} + */ + private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { + // Column names of all partitions must match + val distinctPartitionsColNames = values.map(_.columnNames).distinct + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } + } + + /** + * Converts a string to a `Literal` with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[StringType]]. + */ + private[parquet] def inferPartitionColumnValue( + raw: String, + defaultPartitionName: String): Literal = { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) + } + } + + private val upCastingOrder: Seq[DataType] = + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + + /** + * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" + * types. + */ + private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + val desiredType = { + val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) + // Falls back to string if all values of this column are null or empty string + if (topType == NullType) StringType else topType + } + + literals.map { case l @ Literal(_, dataType) => + Literal.create(Cast(l, desiredType).eval(), desiredType) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala new file mode 100644 index 000000000000..70bcca7526aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet.timestamp + +import java.nio.{ByteBuffer, ByteOrder} + +import parquet.Preconditions +import parquet.io.api.{Binary, RecordConsumer} + +private[parquet] class NanoTime extends Serializable { + private var julianDay = 0 + private var timeOfDayNanos = 0L + + def set(julianDay: Int, timeOfDayNanos: Long): this.type = { + this.julianDay = julianDay + this.timeOfDayNanos = timeOfDayNanos + this + } + + def getJulianDay: Int = julianDay + + def getTimeOfDayNanos: Long = timeOfDayNanos + + def toBinary: Binary = { + val buf = ByteBuffer.allocate(12) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + buf.flip() + Binary.fromByteBuffer(buf) + } + + def writeValue(recordConsumer: RecordConsumer): Unit = { + recordConsumer.addBinary(toBinary) + } + + override def toString: String = + "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" +} + +private[sql] object NanoTime { + def fromBinary(bytes: Binary): NanoTime = { + Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") + val buf = bytes.toByteBuffer + buf.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + new NanoTime().set(julianDay, timeOfDayNanos) + } + + def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { + new NanoTime().set(julianDay, timeOfDayNanos) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d0301..b3d71f687a60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,19 +18,20 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{UTF8String, StringType} +import org.apache.spark.sql.{Row, Strategy, execution, sources} /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, @@ -53,7 +54,11 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil + + case i @ logical.InsertIntoTable( + l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => + execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil } @@ -82,7 +87,7 @@ private[sql] object DataSourceStrategy extends Strategy { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(And) + val filterCondition = filterPredicates.reduceLeftOption(expressions.And) val pushedFilters = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. @@ -98,38 +103,93 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } - protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) - - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) - - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => - GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => - LessThanOrEqual(a.name, v) + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + /** + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, + * and convert them. + */ + protected[sql] def selectFilters(filters: Seq[Expression]) = { + def translate(predicate: Expression): Option[Filter] = predicate match { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => + Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Literal(v, _), a: Attribute) => + Some(sources.EqualTo(a.name, v)) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThan(a.name, v)) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => + Some(sources.LessThan(a.name, v)) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => + Some(sources.LessThan(a.name, v)) + case expressions.LessThan(Literal(v, _), a: Attribute) => + Some(sources.GreaterThan(a.name, v)) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, v)) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => + Some(sources.LessThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, v)) + + case expressions.InSet(a: Attribute, set) => + Some(sources.In(a.name, set.toArray)) + + case expressions.IsNull(a: Attribute) => + Some(sources.IsNull(a.name)) + case expressions.IsNotNull(a: Attribute) => + Some(sources.IsNotNull(a.name)) + + case expressions.And(left, right) => + (translate(left) ++ translate(right)).reduceOption(sources.And) + + case expressions.Or(left, right) => + for { + leftFilter <- translate(left) + rightFilter <- translate(right) + } yield sources.Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translate(child).map(sources.Not) + + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) + + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) + + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) + + case _ => None + } - case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) + filters.flatMap(translate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 12b59ba20bb1..598ef18a379e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -23,31 +23,35 @@ import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, Logica /** * Used to link a [[BaseRelation]] in to a logical query plan. */ -private[sql] case class LogicalRelation(relation: BaseRelation) +private[sql] case class LogicalRelation(relation: BaseRelation) // 叫 externalLogicalRelation比较合适? extends LeafNode with MultiInstanceRelation { override val output: Seq[AttributeReference] = relation.schema.toAttributes // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output case _ => false } - override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + override def hashCode: Int = { + com.google.common.base.Objects.hashCode(relation, output) + } + + override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { case LogicalRelation(otherRelation) => relation == otherRelation case _ => false } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = BigInt(relation.sizeInBytes) ) /** Used to lookup original attribute capitalization */ - val attributeMap = AttributeMap(output.map(o => (o, o))) + val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type] - override def simpleString = s"Relation[${output.mkString(",")}] $relation" + override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala new file mode 100644 index 000000000000..dbdb0d39c26a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand + +private[sql] case class InsertIntoDataSource( + logicalRelation: LogicalRelation, + query: LogicalPlan, + overwrite: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] + val data = DataFrame(sqlContext, query) + // Apply the schema of the existing table to the new data. + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) + relation.insert(df, overwrite) + + // Invalidate the cache. + sqlContext.cacheManager.invalidateCache(logicalRelation) + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 171b816a2633..e7a0685e013d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -17,223 +17,362 @@ package org.apache.spark.sql.sources +import scala.language.existentials +import scala.util.matching.Regex import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SchemaRDD, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils - /** * A parser for foreign DDL commands. */ -private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { +private[sql] class DDLParser( + parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { - def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { try { - Some(apply(input)) + Some(parse(input)) } catch { + case ddlException: DDLException => throw ddlException case _ if !exceptionOnError => None case x: Throwable => throw x } } - def parseType(input: String): DataType = { - lexical.initialize(reservedWords) - phrase(dataType)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => - sys.error(s"Unsupported dataType: $x") - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") - // Data types. - protected val STRING = Keyword("STRING") - protected val BINARY = Keyword("BINARY") - protected val BOOLEAN = Keyword("BOOLEAN") - protected val TINYINT = Keyword("TINYINT") - protected val SMALLINT = Keyword("SMALLINT") - protected val INT = Keyword("INT") - protected val BIGINT = Keyword("BIGINT") - protected val FLOAT = Keyword("FLOAT") - protected val DOUBLE = Keyword("DOUBLE") - protected val DECIMAL = Keyword("DECIMAL") - protected val DATE = Keyword("DATE") - protected val TIMESTAMP = Keyword("TIMESTAMP") - protected val VARCHAR = Keyword("VARCHAR") - protected val ARRAY = Keyword("ARRAY") - protected val MAP = Keyword("MAP") - protected val STRUCT = Keyword("STRUCT") - - protected lazy val ddl: Parser[LogicalPlan] = createTable + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable protected def start: Parser[LogicalPlan] = ddl /** - * `CREATE [TEMPORARY] TABLE avroTable + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... */ protected lazy val createTable: Parser[LogicalPlan] = - ( - (CREATE ~> TEMPORARY.? <~ TABLE) ~ ident - ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { - case temp ~ tableName ~ columns ~ provider ~ opts => - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing(tableName, userSpecifiedSchema, provider, temp.isDefined, opts) + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableName, + provider, + temp.isDefined, + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableName, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } } - ) protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { + case maybeDatabaseName ~ tableName => + RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) + } + protected lazy val options: Parser[Map[String, String]] = "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = - ident ~ dataType ^^ { case columnName ~ typ => - StructField(columnName, typ) - } + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } - protected lazy val primitiveType: Parser[DataType] = - STRING ^^^ StringType | - BINARY ^^^ BinaryType | - BOOLEAN ^^^ BooleanType | - TINYINT ^^^ ByteType | - SMALLINT ^^^ ShortType | - INT ^^^ IntegerType | - BIGINT ^^^ LongType | - FLOAT ^^^ FloatType | - DOUBLE ^^^ DoubleType | - fixedDecimalType | // decimal with precision/scale - DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale - DATE ^^^ DateType | - TIMESTAMP ^^^ TimestampType | - VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType - - protected lazy val fixedDecimalType: Parser[DataType] = - (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + StructField(columnName, typ, nullable = true, meta) } +} - protected lazy val arrayType: Parser[DataType] = - ARRAY ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } +private[sql] object ResolvedDataSource { - protected lazy val mapType: Parser[DataType] = - MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } + private val builtinSources = Map( + "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], + "json" -> classOf[org.apache.spark.sql.json.DefaultSource], + "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + ) - protected lazy val structField: Parser[StructField] = - ident ~ ":" ~ dataType ^^ { - case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true) + /** Given a provider name, look up the data source class definition. */ + def lookupDataSource(provider: String): Class[_] = { + if (builtinSources.contains(provider)) { + return builtinSources(provider) } - protected lazy val structType: Parser[DataType] = - (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => StructType(fields) - }) | - (STRUCT ~> "<>" ^^ { - case fields => StructType(Nil) - }) - - private[sql] lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType -} - -object ResolvedDataSource { - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]): ResolvedDataSource = { val loader = Utils.getContextOrSparkClassLoader - val clazz: Class[_] = try loader.loadClass(provider) catch { + try { + loader.loadClass(provider) + } catch { case cnf: java.lang.ClassNotFoundException => - try loader.loadClass(provider + ".DefaultSource") catch { + try { + loader.loadClass(provider + ".DefaultSource") + } catch { case cnf: java.lang.ClassNotFoundException => sys.error(s"Failed to load class for data source: $provider") } } + } + /** Create a [[ResolvedDataSource]] for reading data in. */ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName val relation = userSpecifiedSchema match { - case Some(schema: StructType) => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") - } + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } - case None => { - clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options)) - case _ => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") - } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } } + new ResolvedDataSource(clazz, relation) + } + /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ + def apply( + sqlContext: SQLContext, + provider: String, + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } new ResolvedDataSource(clazz, relation) } } private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +/** + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * @param table The table to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * It is effective only when the table is a Hive table. + */ +private[sql] case class DescribeCommand( + table: LogicalPlan, + isExtended: Boolean) extends Command { + override val output = Seq( + // Column names are based on Hive. + AttributeReference("col_name", StringType, nullable = false, + new MetadataBuilder().putString("comment", "name of the column").build())(), + AttributeReference("data_type", StringType, nullable = false, + new MetadataBuilder().putString("comment", "data type of the column").build())(), + AttributeReference("comment", StringType, nullable = false, + new MetadataBuilder().putString("comment", "comment of the column").build())()) +} + +/** + * Used to represent the operation of create table using a data source. + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + */ private[sql] case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, - options: Map[String, String]) extends Command + options: Map[String, String], + allowExisting: Boolean, + managedIfNoPath: Boolean) extends Command + +/** + * A node used to support CTAS statements and saveAsTable for the data source API. + * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer + * can analyze the logical plan that will be used to populate the table. + * So, [[PreWriteCheck]] can detect cases that are not allowed. + */ +private[sql] case class CreateTableUsingAsSelect( + tableName: String, + provider: String, + temporary: Boolean, + mode: SaveMode, + options: Map[String, String], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = Seq.empty[Attribute] + // TODO: Override resolved after we support databaseName. + // override lazy val resolved = databaseName != None && childrenResolved +} -private [sql] case class CreateTempTableUsing( +private[sql] case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext) = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) + sqlContext.registerDataFrameAsTable( + DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + Seq.empty + } +} + +private[sql] case class CreateTempTableUsingAsSelect( + tableName: String, + provider: String, + mode: SaveMode, + options: Map[String, String], + query: LogicalPlan) extends RunnableCommand { + + def run(sqlContext: SQLContext): Seq[Row] = { + val df = DataFrame(sqlContext, query) + val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df) + sqlContext.registerDataFrameAsTable( + DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + Seq.empty } } +private[sql] case class RefreshTable(databaseName: String, tableName: String) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + // Refresh the given table's metadata first. + sqlContext.catalog.refreshTable(databaseName, tableName) + + // If this table is cached as a InMemoryColumnarRelation, drop the original + // cached version and make the new version cached lazily. + val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + // Use lookupCachedData directly since RefreshTable also takes databaseName. + val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty + if (isCached) { + // Create a data frame to represent the table. + // TODO: Use uncacheTable once it supports database name. + val df = DataFrame(sqlContext, logicalPlan) + // Uncache the logicalPlan. + sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) + // Cache it again. + sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + } + + Seq.empty[Row] + } +} + /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -245,5 +384,10 @@ protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, override def iterator: Iterator[(String, String)] = baseMap.iterator - override def -(key: String): Map[String, String] = baseMap - key.toLowerCase() + override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } + +/** + * The exception thrown from the DDL parser. + */ +protected[sql] class DDLException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4a9fefc12b9a..791046e0079d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,11 +17,85 @@ package org.apache.spark.sql.sources +/** + * A filter predicate for data sources. + */ abstract class Filter +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * equal to `value`. + */ case class EqualTo(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than `value`. + */ case class GreaterThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * greater than or equal to `value`. + */ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than `value`. + */ case class LessThan(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a value + * less than or equal to `value`. + */ case class LessThanOrEqual(attribute: String, value: Any) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to one of the values in the array. + */ case class In(attribute: String, values: Array[Any]) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to null. + */ +case class IsNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to a non-null value. + */ +case class IsNotNull(attribute: String) extends Filter + +/** + * A filter that evaluates to `true` iff both `left` or `right` evaluate to `true`. + */ +case class And(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff at least one of `left` or `right` evaluates to `true`. + */ +case class Or(left: Filter, right: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff `child` is evaluated to `false`. + */ +case class Not(child: Filter) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringStartsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that starts with `value`. + */ +case class StringEndsWith(attribute: String, value: String) extends Filter + +/** + * A filter that evaluates to `true` iff the attribute evaluates to + * a string that contains the string `value`. + */ +case class StringContains(attribute: String, value: String) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index cd82cc6ecb61..ca53dcdb92c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -14,11 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} import org.apache.spark.sql.types.StructType @@ -77,15 +78,36 @@ trait SchemaRelationProvider { schema: StructType): BaseRelation } +@DeveloperApi +trait CreatableRelationProvider { + /** + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + */ + def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation +} + /** * ::DeveloperApi:: - * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must - * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * Represents a collection of tuples with a known schema. Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]]. Concrete * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * * BaseRelations must also define a equality function that only returns true when the two - * instances will return the same data. This equality function is used when determining when + * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. */ @DeveloperApi @@ -94,13 +116,26 @@ abstract class BaseRelation { def schema: StructType /** - * Returns an estimated size of this relation in bytes. This information is used by the planner + * Returns an estimated size of this relation in bytes. This information is used by the planner * to decided when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too - * large to broadcast. This method will be called multiple times during query planning + * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. + * + * Note that it is always better to overestimate size than underestimate, because underestimation + * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ - def sizeInBytes = sqlContext.conf.defaultSizeInBytes + def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. + */ + def needConversion: Boolean = true } /** @@ -108,7 +143,7 @@ abstract class BaseRelation { * A BaseRelation that can produce all of its tuples as an RDD of Row objects. */ @DeveloperApi -abstract class TableScan extends BaseRelation { +trait TableScan { def buildScan(): RDD[Row] } @@ -118,7 +153,7 @@ abstract class TableScan extends BaseRelation { * containing all of its tuples as Row objects. */ @DeveloperApi -abstract class PrunedScan extends BaseRelation { +trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -127,24 +162,48 @@ abstract class PrunedScan extends BaseRelation { * A BaseRelation that can eliminate unneeded columns and filter using selected * predicates before producing an RDD containing all matching tuples as Row objects. * + * The actual filter should be the conjunction of all `filters`, + * i.e. they should be "and" together. + * * The pushed down filters are currently purely an optimization as they will all be evaluated * again. This means it is safe to use them with methods that produce false positives such * as filtering partitions based on a bloom filter. */ @DeveloperApi -abstract class PrunedFilteredScan extends BaseRelation { +trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } +/** + * ::DeveloperApi:: + * A BaseRelation that can be used to insert data into it through the insert method. + * If overwrite in insert method is true, the old data in the relation should be overwritten with + * the new data. If overwrite in insert method is false, the new data should be appended. + * + * InsertableRelation has the following three assumptions. + * 1. It assumes that the data (Rows in the DataFrame) provided to the insert method + * exactly matches the ordinal of fields in the schema of the BaseRelation. + * 2. It assumes that the schema of this relation will not be changed. + * Even if the insert method updates the schema (e.g. a relation of JSON or Parquet data may have a + * schema update after an insert operation), the new schema will not be used. + * 3. It assumes that fields of the data provided in the insert method are nullable. + * If a data source needs to check the actual nullability of a field, it needs to do it in the + * insert method. + */ +@DeveloperApi +trait InsertableRelation { + def insert(data: DataFrame, overwrite: Boolean): Unit +} + /** * ::Experimental:: * An interface for experimenting with a more direct connection to the query planner. Compared to * [[PrunedFilteredScan]], this operator receives the raw expressions from the * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. Unlike the other APIs this - * interface is not designed to be binary compatible across releases and thus should only be used + * interface is NOT designed to be binary compatible across releases and thus should only be used * for experimentation. */ @Experimental -abstract class CatalystScan extends BaseRelation { +trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala new file mode 100644 index 000000000000..6ed68d179edc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{SaveMode, AnalysisException} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + +/** + * A rule to do pre-insert data type casting and field renaming. Before we insert into + * an [[InsertableRelation]], we will use this rule to make sure that + * the columns to be inserted have the correct data type and fields have the correct names. + */ +private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Wait until children are resolved. + case p: LogicalPlan if !p.childrenResolved => p + + // We are inserting into an InsertableRelation. + case i @ InsertIntoTable( + l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { + // First, make sure the data to be inserted have the same number of fields with the + // schema of the relation. + if (l.output.size != child.output.size) { + sys.error( + s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + + s"statement generates the same number of columns as its schema.") + } + castAndRenameChildOutput(i, l.output, child) + } + } + + /** If necessary, cast data types and rename fields to the expected types and names. */ + def castAndRenameChildOutput( + insertInto: InsertIntoTable, + expectedOutput: Seq[Attribute], + child: LogicalPlan): InsertIntoTable = { + val newChildOutput = expectedOutput.zip(child.output).map { + case (expected, actual) => + val needCast = !expected.dataType.sameType(actual.dataType) + // We want to make sure the filed names in the data to be inserted exactly match + // names in the schema. + val needRename = expected.name != actual.name + (needCast, needRename) match { + case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() + case (false, true) => Alias(actual, expected.name)() + case (_, _) => actual + } + } + + if (newChildOutput == child.output) { + insertInto + } else { + insertInto.copy(child = Project(newChildOutput, child)) + } + } +} + +/** + * A rule to do various checks before inserting into or writing to a data source table. + */ +private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { + def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } + + def apply(plan: LogicalPlan): Unit = { + plan.foreach { + case i @ logical.InsertIntoTable( + l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite, ifNotExists) => + // Right now, we do not support insert into a data source table with partition specs. + if (partition.nonEmpty) { + failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") + } else { + // Get all input data source relations of the query. + val srcRelations = query.collect { + case LogicalRelation(src: BaseRelation) => src + } + if (srcRelations.contains(t)) { + failAnalysis( + "Cannot insert overwrite into table that is also being read from.") + } else { + // OK + } + } + + case i @ logical.InsertIntoTable( + l: LogicalRelation, partition, query, overwrite, ifNotExists) + if !l.isInstanceOf[InsertableRelation] => + // The relation in l is not an InsertableRelation. + failAnalysis(s"$l does not allow insertion.") + + case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, query) => + // When the SaveMode is Overwrite, we need to check if the table is an input table of + // the query. If so, we will throw an AnalysisException to let users know it is not allowed. + if (catalog.tableExists(Seq(tableName))) { + // Need to remove SubQuery operator. + EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + // Only do the check if the table is a data source table + // (the relation is a BaseRelation). + case l @ LogicalRelation(dest: BaseRelation) => + // Get all input data source relations of the query. + val srcRelations = query.collect { + case LogicalRelation(src: BaseRelation) => src + } + if (srcRelations.contains(dest)) { + failAnalysis( + s"Cannot overwrite table $tableName that is also being read from.") + } else { + // OK + } + + case _ => // OK + } + } else { + // OK + } + + case _ => // OK + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 006b16fbe07b..2fdd798b44bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { @@ -59,4 +59,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { } override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f9c082216085..356a6100d2cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,28 +20,37 @@ package org.apache.spark.sql.test import scala.language.implicitConversions import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ -object TestSQLContext +class LocalSQLContext extends SQLContext( new SparkContext( "local[2]", "TestSQLContext", new SparkConf().set("spark.sql.testkey", "true"))) { - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } + + protected[sql] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + /** Fewer partitions to speed up testing. */ + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + } } /** - * Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to - * construct SchemaRDD directly out of local data without relying on implicits. + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to + * construct [[DataFrame]] directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { - new SchemaRDD(this, plan) + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(this, plan) } } + +object TestSQLContext extends LocalSQLContext + diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java similarity index 69% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 9e96738ac095..c344a9b095c5 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -38,19 +39,18 @@ // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { private transient JavaSparkContext javaCtx; - private transient SQLContext javaSqlCtx; + private transient SQLContext sqlContext; @Before public void setUp() { - javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); - javaSqlCtx = new SQLContext(javaCtx); + sqlContext = TestSQLContext$.MODULE$; + javaCtx = new JavaSparkContext(sqlContext.sparkContext()); } @After public void tearDown() { - javaCtx.stop(); javaCtx = null; - javaSqlCtx = null; + sqlContext = null; } public static class Person implements Serializable { @@ -98,9 +98,9 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - schemaRDD.registerTempTable("people"); - Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); expected.add(RowFactory.create("Michael", 29)); @@ -109,6 +109,46 @@ public Row call(Person person) throws Exception { Assert.assertEquals(expected, Arrays.asList(actual)); } + @Test + public void dataFrameRDDOperations() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); + + DataFrame df = sqlContext.applySchema(rowRDD, schema); + df.registerTempTable("people"); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { + + public String call(Row row) { + return row.getString(0) + "_" + row.get(1).toString(); + } + }).collect(); + + List expected = new ArrayList(2); + expected.add("Michael_29"); + expected.add("Yin_28"); + + Assert.assertEquals(expected, actual); + } + @Test public void applySchemaToJSON() { JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( @@ -122,7 +162,7 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); - fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); @@ -147,18 +187,18 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = schemaRDD1.schema(); + DataFrame df1 = sqlContext.jsonRDD(jsonRDD); + StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerTempTable("jsonTable1"); - List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); + df1.registerTempTable("jsonTable1"); + List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = schemaRDD2.schema(); + DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD2.registerTempTable("jsonTable2"); - List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); + df2.registerTempTable("jsonTable2"); + List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java new file mode 100644 index 000000000000..e02c84872c62 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.TestData$; +import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.types.*; +import org.junit.*; + +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.collection.mutable.Buffer; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.apache.spark.sql.functions.*; + +public class JavaDataFrameSuite { + private transient JavaSparkContext jsc; + private transient SQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + TestData$.MODULE$.testData(); + jsc = new JavaSparkContext(TestSQLContext.sparkContext()); + context = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + jsc = null; + context = null; + } + + @Test + public void testExecution() { + DataFrame df = context.table("testData").filter("key = 1"); + Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + } + + /** + * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. + */ + @Test + public void testVarargMethods() { + DataFrame df = context.table("testData"); + + df.toDF("key1", "value1"); + + df.select("key", "value"); + df.select(col("key"), col("value")); + df.selectExpr("key", "value + 1"); + + df.sort("key", "value"); + df.sort(col("key"), col("value")); + df.orderBy("key", "value"); + df.orderBy(col("key"), col("value")); + + df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); + df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); + df.agg(first("key"), sum("value")); + + df.groupBy().avg("key"); + df.groupBy().mean("key"); + df.groupBy().max("key"); + df.groupBy().min("key"); + df.groupBy().sum("key"); + + // Varargs in column expressions + df.groupBy().agg(countDistinct("key", "value")); + df.groupBy().agg(countDistinct(col("key"), col("value"))); + df.select(coalesce(col("key"))); + } + + @Ignore + public void testShow() { + // This test case is intended ignored, but to make sure it compiles correctly + DataFrame df = context.table("testData"); + df.show(); + df.show(1000); + } + + public static class Bean implements Serializable { + private double a = 0.0; + private Integer[] b = new Integer[]{0, 1}; + private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); + private List d = Arrays.asList("floppy", "disk"); + + public double getA() { + return a; + } + + public Integer[] getB() { + return b; + } + + public Map getC() { + return c; + } + + public List getD() { + return d; + } + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + StructType schema = df.schema(); + Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), + schema.apply("a")); + Assert.assertEquals( + new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), + schema.apply("b")); + ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); + MapType mapType = new MapType(DataTypes.StringType, valueType, true); + Assert.assertEquals( + new StructField("c", mapType, true, Metadata.empty()), + schema.apply("c")); + Assert.assertEquals( + new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), + schema.apply("d")); + Row first = df.select("a", "b", "c", "d").first(); + Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); + // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // verify that it has the expected length, and contains expected elements. + Seq result = first.getAs(1); + Assert.assertEquals(bean.getB().length, result.length()); + for (int i = 0; i < result.length(); i++) { + Assert.assertEquals(bean.getB()[i], result.apply(i)); + } + Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Assert.assertArrayEquals( + bean.getC().get("hello"), + Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Seq d = first.getAs(3); + Assert.assertEquals(bean.getD().size(), d.length()); + for (int i = 0; i < d.length(); i++) { + Assert.assertEquals(bean.getD().get(i), d.apply(i)); + } + } + +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java similarity index 99% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index fbfcd3f59d91..4ce1d1dddb26 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.math.BigDecimal; import java.sql.Date; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java similarity index 88% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 9ff40471a00a..79d92734ff37 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; @@ -23,28 +23,29 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { +public class JavaUDFSuite implements Serializable { private transient JavaSparkContext sc; private transient SQLContext sqlContext; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - sqlContext = new SQLContext(sc); + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; } @SuppressWarnings("unchecked") @@ -61,7 +62,7 @@ public Integer call(String str) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); assert(result.getInt(0) == 4); } @@ -81,7 +82,7 @@ public Integer call(String str1, String str2) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); assert(result.getInt(0) == 9); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java new file mode 100644 index 000000000000..b76f7d421f64 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaSaveLoadSuite { + + private transient JavaSparkContext sc; + private transient SQLContext sqlContext; + + String originalDefaultSource; + File path; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestSQLContext$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @Test + public void saveAndLoad() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + } + + @Test + public void saveAndLoadWithSchema() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + + checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + } +} diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index fbed0a782dd3..28e90b9520b2 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -39,6 +39,9 @@ log4j.appender.FA.Threshold = INFO log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index cfc037caff2a..0772e5e18742 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,10 +17,17 @@ package org.apache.spark.sql +import scala.concurrent.duration._ +import scala.language.{implicitConversions, postfixOps} + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -50,17 +57,17 @@ class CachedTableSuite extends QueryTest { } test("unpersist an uncached table will not raise exception") { - assert(None == lookupCachedData(testData)) - testData.unpersist(true) - assert(None == lookupCachedData(testData)) - testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = true) + assert(None == cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = false) + assert(None == cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != lookupCachedData(testData)) - testData.unpersist(true) - assert(None == lookupCachedData(testData)) - testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None != cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = true) + assert(None == cacheManager.lookupCachedData(testData)) + testData.unpersist(blocking = false) + assert(None == cacheManager.lookupCachedData(testData)) } test("cache table as select") { @@ -86,7 +93,8 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + .registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) @@ -190,7 +198,10 @@ class CachedTableSuite extends QueryTest { sql("UNCACHE TABLE testData") assert(!isCached("testData"), "Table 'testData' should not be cached") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { @@ -203,7 +214,9 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT ...") { @@ -216,7 +229,9 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE LAZY TABLE tableName") { @@ -234,7 +249,9 @@ class CachedTableSuite extends QueryTest { "Lazily cached in-memory table should have been materialized") uncacheTable("testData") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("InMemoryRelation statistics") { @@ -265,4 +282,44 @@ class CachedTableSuite extends QueryTest { assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) assert(!isCached("t2")) } + + test("Clear all cache") { + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + clearCache() + assert(cacheManager.isEmpty) + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("Clear CACHE") + assert(cacheManager.isEmpty) + } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + cacheTable("t1") + cacheTable("t2") + assert((accsSize + 2) == Accumulators.originals.size) + } + + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + uncacheTable("t1") + uncacheTable("t2") + assert((accsSize - 2) == Accumulators.originals.size) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala new file mode 100644 index 000000000000..bc8fae100db6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ + + +class ColumnExpressionSuite extends QueryTest { + import org.apache.spark.sql.TestData._ + + // TODO: Add test cases for bitwise operations. + + test("collect on column produced by a binary operator") { + val df = Seq((1, 2, 3)).toDF("a", "b", "c") + checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) + checkAnswer(df.select(df("a") + df("b").as("c")), Seq(Row(3))) + } + + test("star") { + checkAnswer(testData.select($"*"), testData.collect().toSeq) + } + + test("star qualified by data frame object") { + val df = testData.toDF + val goldAnswer = df.collect().toSeq + checkAnswer(df.select(df("*")), goldAnswer) + + val df1 = df.select(df("*"), lit("abcd").as("litCol")) + checkAnswer(df1.select(df("*")), goldAnswer) + } + + test("star qualified by table name") { + checkAnswer(testData.as("testData").select($"testData.*"), testData.collect().toSeq) + } + + test("+") { + checkAnswer( + testData2.select($"a" + 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + 1))) + + checkAnswer( + testData2.select($"a" + $"b" + 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + r.getInt(1) + 2))) + } + + test("-") { + checkAnswer( + testData2.select($"a" - 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - 1))) + + checkAnswer( + testData2.select($"a" - $"b" - 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - r.getInt(1) - 2))) + } + + test("*") { + checkAnswer( + testData2.select($"a" * 10), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * 10))) + + checkAnswer( + testData2.select($"a" * $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * r.getInt(1)))) + } + + test("/") { + checkAnswer( + testData2.select($"a" / 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / 2))) + + checkAnswer( + testData2.select($"a" / $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / r.getInt(1)))) + } + + + test("%") { + checkAnswer( + testData2.select($"a" % 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % 2))) + + checkAnswer( + testData2.select($"a" % $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % r.getInt(1)))) + } + + test("unary -") { + checkAnswer( + testData2.select(-$"a"), + testData2.collect().toSeq.map(r => Row(-r.getInt(0)))) + } + + test("unary !") { + checkAnswer( + complexData.select(!$"b"), + complexData.collect().toSeq.map(r => Row(!r.getBoolean(3)))) + } + + test("isNull") { + checkAnswer( + nullStrings.toDF.where($"s".isNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + } + + test("isNotNull") { + checkAnswer( + nullStrings.toDF.where($"s".isNotNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + } + + test("===") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("<=>") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("!==") { + val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + Row(1, 1) :: + Row(1, 2) :: + Row(1, null) :: + Row(null, null) :: Nil), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) + + checkAnswer( + nullData.filter($"b" <=> 1), + Row(1, 1) :: Nil) + + checkAnswer( + nullData.filter($"b" <=> null), + Row(1, null) :: Row(null, null) :: Nil) + + checkAnswer( + nullData.filter($"a" <=> $"b"), + Row(1, 1) :: Row(null, null) :: Nil) + } + + test(">") { + checkAnswer( + testData2.filter($"a" > 1), + testData2.collect().toSeq.filter(r => r.getInt(0) > 1)) + + checkAnswer( + testData2.filter($"a" > $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1))) + } + + test(">=") { + checkAnswer( + testData2.filter($"a" >= 1), + testData2.collect().toSeq.filter(r => r.getInt(0) >= 1)) + + checkAnswer( + testData2.filter($"a" >= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1))) + } + + test("<") { + checkAnswer( + testData2.filter($"a" < 2), + testData2.collect().toSeq.filter(r => r.getInt(0) < 2)) + + checkAnswer( + testData2.filter($"a" < $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) < r.getInt(1))) + } + + test("<=") { + checkAnswer( + testData2.filter($"a" <= 2), + testData2.collect().toSeq.filter(r => r.getInt(0) <= 2)) + + checkAnswer( + testData2.filter($"a" <= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) + } + + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + + test("&&") { + checkAnswer( + booleanData.filter($"a" && true), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" && false), + Nil) + + checkAnswer( + booleanData.filter($"a" && $"b"), + Row(true, true) :: Nil) + } + + test("||") { + checkAnswer( + booleanData.filter($"a" || true), + booleanData.collect()) + + checkAnswer( + booleanData.filter($"a" || false), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" || $"b"), + Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) + } + + test("sqrt") { + checkAnswer( + testData.select(sqrt('key)).orderBy('key.asc), + (1 to 100).map(n => Row(math.sqrt(n))) + ) + + checkAnswer( + testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), + (1 to 100).map(n => Row(math.sqrt(n), n)) + ) + + checkAnswer( + testData.select(sqrt(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("abs") { + checkAnswer( + testData.select(abs('key)).orderBy('key.asc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + negativeData.select(abs('key)).orderBy('key.desc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + testData.select(abs(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("upper") { + checkAnswer( + lowerCaseData.select(upper('l)), + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) + ) + + checkAnswer( + testData.select(upper('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(upper(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } + + test("lower") { + checkAnswer( + upperCaseData.select(lower('L)), + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) + ) + + checkAnswer( + testData.select(lower('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(lower(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } + + test("lift alias out of cast") { + compareExpressions( + col("1234").as("name").cast("int").expr, + col("1234").cast("int").as("name").expr) + } + + test("columns can be compared") { + assert('key.desc == 'key.desc) + assert('key.desc != 'key.asc) + } + + test("alias with metadata") { + val metadata = new MetadataBuilder() + .putString("originName", "value") + .build() + val schema = testData + .select($"*", col("value").as("abc", metadata)) + .schema + assert(schema("value").metadata === Metadata.empty) + assert(schema("abc").metadata === metadata) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala new file mode 100644 index 000000000000..2d2367d6e729 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -0,0 +1,55 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameImplicitsSuite extends QueryTest { + + test("RDD of tuples") { + checkAnswer( + sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("Seq of tuples") { + checkAnswer( + (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + (1 to 10).map(i => Row(i, i.toString))) + } + + test("RDD[Int]") { + checkAnswer( + sc.parallelize(1 to 10).toDF("intCol"), + (1 to 10).map(i => Row(i))) + } + + test("RDD[Long]") { + checkAnswer( + sc.parallelize(1L to 10L).toDF("longCol"), + (1L to 10L).map(i => Row(i))) + } + + test("RDD[String]") { + checkAnswer( + sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + (1 to 10).map(i => Row(i.toString))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala new file mode 100644 index 000000000000..41b4f02e6a29 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.test.TestSQLContext.implicits._ + + +class DataFrameNaFunctionsSuite extends QueryTest { + + def createDF(): DataFrame = { + Seq[(String, java.lang.Integer, java.lang.Double)]( + ("Bob", 16, 176.5), + ("Alice", null, 164.3), + ("David", 60, null), + ("Amy", null, null), + (null, null, null)).toDF("name", "age", "height") + } + + test("drop") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("name" :: Nil), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("age" :: Nil), + rows(0) :: rows(2) :: Nil) + + checkAnswer( + input.na.drop("age" :: "height" :: Nil), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(), + rows(0)) + + // dropna on an a dataframe with no column should return an empty data frame. + val empty = input.sqlContext.emptyDataFrame.select() + assert(empty.na.drop().count() === 0L) + + // Make sure the columns are properly named. + assert(input.na.drop().columns.toSeq === input.columns.toSeq) + } + + test("drop with how") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop("all"), + rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + + checkAnswer( + input.na.drop("any"), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("any", Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop("all", Seq("age", "height")), + rows(0) :: rows(1) :: rows(2) :: Nil) + } + + test("drop with threshold") { + val input = createDF() + val rows = input.collect() + + checkAnswer( + input.na.drop(2, Seq("age", "height")), + rows(0) :: Nil) + + checkAnswer( + input.na.drop(3, Seq("name", "age", "height")), + rows(0)) + + // Make sure the columns are properly named. + assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq) + } + + test("fill") { + val input = createDF() + + val fillNumeric = input.na.fill(50.6) + checkAnswer( + fillNumeric, + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, 50.6) :: + Row("Amy", 50, 50.6) :: + Row(null, 50, 50.6) :: Nil) + + // Make sure the columns are properly named. + assert(fillNumeric.columns.toSeq === input.columns.toSeq) + + // string + checkAnswer( + input.na.fill("unknown").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) + + // fill double with subset columns + checkAnswer( + input.na.fill(50.6, "age" :: Nil), + Row("Bob", 16, 176.5) :: + Row("Alice", 50, 164.3) :: + Row("David", 60, null) :: + Row("Amy", 50, null) :: + Row(null, 50, null) :: Nil) + + // fill string with subset columns + checkAnswer( + Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), + Row("test", null)) + } + + test("fill with map") { + val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( + (null, null, null, null)).toDF("a", "b", "c", "d") + checkAnswer( + df.na.fill(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + )), + Row("test", null, 1, 2.2)) + + // Test Java version + checkAnswer( + df.na.fill(mapAsJavaMap(Map( + "a" -> "test", + "c" -> 1, + "d" -> 2.2 + ))), + Row("test", null, 1, 2.2)) + } + + test("replace") { + val input = createDF() + + // Replace two numeric columns: age and height + val out = input.na.replace(Seq("age", "height"), Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 461.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + + // Replace only the age column + val out1 = input.na.replace("age", Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out1, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 164.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala new file mode 100644 index 000000000000..5ec06d448e50 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -0,0 +1,583 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} +import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.TestSQLContext.sql + + +class DataFrameSuite extends QueryTest { + import org.apache.spark.sql.TestData._ + + test("analysis error should be eagerly reported") { + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + // Eager analysis. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } + + // No more eager analysis once the flag is turned off + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + testData.select('nonExistentName) + + // Set the flag back to original value before this test. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + } + + test("dataframe toString") { + assert(testData.toString === "[key: int, value: string]") + assert(testData("key").toString === "key") + assert($"test".toString === "test") + } + + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + + test("invalid plan toString, debug mode") { + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + + // Turn on debug mode so we can see invalid query plans. + import org.apache.spark.sql.execution.debug._ + TestSQLContext.debug() + + val badPlan = testData.select('badColumn) + + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + + // Set the flag back to original value before this test. + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + } + + test("access complex data") { + assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) + assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) + } + + test("table scan") { + checkAnswer( + testData, + testData.collect().toSeq) + } + + test("empty data frame") { + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) + } + + test("head and take") { + assert(testData.take(2) === testData.collect().take(2)) + assert(testData.head(2) === testData.collect().take(2)) + assert(testData.head(2).head.schema === testData.schema) + } + + test("simple explode") { + val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") + + checkAnswer( + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil + ) + } + + test("join - join using") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") + + checkAnswer( + df.join(df2, "int"), + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) + } + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + test("explode") { + val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") + val df2 = + df.explode('letters) { + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + } + + checkAnswer( + df2 + .select('_1 as 'letter, 'number) + .groupBy('letter) + .agg('letter, countDistinct('number)), + Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil + ) + } + + test("selectExpr") { + checkAnswer( + testData.selectExpr("abs(key)", "value"), + testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq) + } + + test("selectExpr with alias") { + checkAnswer( + testData.selectExpr("key as k").select("k"), + testData.select("key").collect().toSeq) + } + + test("filterExpr") { + checkAnswer( + testData.filter("key > 90"), + testData.collect().filter(_.getInt(0) > 90).toSeq) + } + + test("repartition") { + checkAnswer( + testData.select('key).repartition(10).select('key), + testData.select('key).collect().toSeq) + } + + test("coalesce") { + assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + + checkAnswer( + testData.select('key).coalesce(1).select('key), + testData.select('key).collect().toSeq) + } + + test("groupBy") { + checkAnswer( + testData2.groupBy("a").agg($"a", sum($"b")), + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) + ) + checkAnswer( + testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), + Row(9) + ) + checkAnswer( + testData2.groupBy("a").agg(col("a"), count("*")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("*" -> "count")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("b" -> "sum")), + Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil + ) + + val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) + .toDF("key", "value1", "value2", "rest") + + checkAnswer( + df1.groupBy("key").min(), + df1.groupBy("key").min("value1", "value2").collect() + ) + checkAnswer( + df1.groupBy("key").min("value2"), + Seq(Row("a", 0), Row("b", 4)) + ) + } + + test("agg without groups") { + checkAnswer( + testData2.agg(sum('b)), + Row(9) + ) + } + + test("convert $\"attribute name\" into unresolved attribute") { + checkAnswer( + testData.where($"key" === lit(1)).select($"value"), + Row("1")) + } + + test("convert Scala Symbol 'attrname into unresolved attribute") { + checkAnswer( + testData.where('key === lit(1)).select('value), + Row("1")) + } + + test("select *") { + checkAnswer( + testData.select($"*"), + testData.collect().toSeq) + } + + test("simple select") { + checkAnswer( + testData.where('key === lit(1)).select('value), + Row("1")) + } + + test("select with functions") { + checkAnswer( + testData.select(sum('value), avg('value), count(lit(1))), + Row(5050.0, 50.5, 100)) + + checkAnswer( + testData2.select('a + 'b, 'a < 'b), + Seq( + Row(2, false), + Row(3, true), + Row(3, false), + Row(4, false), + Row(4, false), + Row(5, false))) + + checkAnswer( + testData2.select(sumDistinct('a)), + Row(6)) + } + + test("global sorting") { + checkAnswer( + testData2.orderBy('a.asc, 'b.asc), + Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + + checkAnswer( + testData2.orderBy(asc("a"), desc("b")), + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + + checkAnswer( + testData2.orderBy('a.asc, 'b.desc), + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.desc), + Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.asc), + Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + + checkAnswer( + arrayData.toDF().orderBy('data.getItem(0).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + + checkAnswer( + arrayData.toDF().orderBy('data.getItem(0).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + + checkAnswer( + arrayData.toDF().orderBy('data.getItem(1).asc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + + checkAnswer( + arrayData.toDF().orderBy('data.getItem(1).desc), + arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) + } + + test("limit") { + checkAnswer( + testData.limit(10), + testData.take(10).toSeq) + + checkAnswer( + arrayData.toDF().limit(1), + arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + + checkAnswer( + mapData.toDF().limit(1), + mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + } + + test("average") { + checkAnswer( + testData2.agg(avg('a)), + Row(2.0)) + + checkAnswer( + testData2.agg(avg('a), sumDistinct('a)), // non-partial + Row(2.0, 6.0) :: Nil) + + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + checkAnswer( + decimalData.agg(avg('a), sumDistinct('a)), // non-partial + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + + checkAnswer( + decimalData.agg(avg('a cast DecimalType(10, 2))), + Row(new java.math.BigDecimal(2.0))) + // non-partial + checkAnswer( + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), + Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) + } + + test("null average") { + checkAnswer( + testData3.agg(avg('b)), + Row(2.0)) + + checkAnswer( + testData3.agg(avg('b), countDistinct('b)), + Row(2.0, 1)) + + checkAnswer( + testData3.agg(avg('b), sumDistinct('b)), // non-partial + Row(2.0, 2.0)) + } + + test("zero average") { + checkAnswer( + emptyTableData.agg(avg('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial + Row(null, null)) + } + + test("count") { + assert(testData2.count() === testData2.map(_ => 1).count()) + + checkAnswer( + testData2.agg(count('a), sumDistinct('a)), // non-partial + Row(6, 6.0)) + } + + test("null count") { + checkAnswer( + testData3.groupBy('a).agg('a, count('b)), + Seq(Row(1,0), Row(2, 1)) + ) + + checkAnswer( + testData3.groupBy('a).agg('a, count('a + 'b)), + Seq(Row(1,0), Row(2, 1)) + ) + + checkAnswer( + testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)), + Row(2, 1, 2, 2, 1) + ) + + checkAnswer( + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial + Row(1, 1, 2) + ) + } + + test("zero count") { + assert(emptyTableData.count() === 0) + + checkAnswer( + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial + Row(0, null)) + } + + test("zero sum") { + checkAnswer( + emptyTableData.agg(sum('a)), + Row(null)) + } + + test("zero sum distinct") { + checkAnswer( + emptyTableData.agg(sumDistinct('a)), + Row(null)) + } + + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + } + + test("udf") { + val foo = udf((a: Int, b: String) => a.toString + b) + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select($"*", foo('key, 'value)).limit(3), + Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil + ) + } + + test("call udf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + + test("withColumn") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) + } + + test("replace column using withColumn") { + val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df3 = df2.withColumn("x", df2("x") + 1) + checkAnswer( + df3.select("x"), + Row(2) :: Row(3) :: Row(4) :: Nil) + } + + test("withColumnRenamed") { + val df = testData.toDF().withColumn("newCol", col("key") + 1) + .withColumnRenamed("value", "valueRenamed") + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1) + }.toSeq) + assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) + } + + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + + test("apply on query results (SPARK-5462)") { + val df = testData.sqlContext.sql("select key from testData") + checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) + } + + ignore("show") { + // This test case is intended ignored, but to make sure it compiles correctly + testData.select($"*").show() + testData.select($"*").show(1000) + } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } + + test("SPARK-6899") { + val originalValue = TestSQLContext.conf.codegenEnabled + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala deleted file mode 100644 index afbfe214f1ce..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ /dev/null @@ -1,395 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/* Implicits */ -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.test.TestSQLContext._ - -import scala.language.postfixOps - -class DslQuerySuite extends QueryTest { - import org.apache.spark.sql.TestData._ - - test("table scan") { - checkAnswer( - testData, - testData.collect().toSeq) - } - - test("repartition") { - checkAnswer( - testData.select('key).repartition(10).select('key), - testData.select('key).collect().toSeq) - } - - test("agg") { - checkAnswer( - testData2.groupBy('a)('a, sum('b)), - Seq(Row(1,3), Row(2,3), Row(3,3)) - ) - checkAnswer( - testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), - Row(9) - ) - checkAnswer( - testData2.aggregate(sum('b)), - Row(9) - ) - } - - test("convert $\"attribute name\" into unresolved attribute") { - checkAnswer( - testData.where($"key" === 1).select($"value"), - Row("1")) - } - - test("convert Scala Symbol 'attrname into unresolved attribute") { - checkAnswer( - testData.where('key === 1).select('value), - Row("1")) - } - - test("select *") { - checkAnswer( - testData.select(Star(None)), - testData.collect().toSeq) - } - - test("simple select") { - checkAnswer( - testData.where('key === 1).select('value), - Row("1")) - } - - test("select with functions") { - checkAnswer( - testData.select(sum('value), avg('value), count(1)), - Row(5050.0, 50.5, 100)) - - checkAnswer( - testData2.select('a + 'b, 'a < 'b), - Seq( - Row(2, false), - Row(3, true), - Row(3, false), - Row(4, false), - Row(4, false), - Row(5, false))) - - checkAnswer( - testData2.select(sumDistinct('a)), - Row(6)) - } - - test("global sorting") { - checkAnswer( - testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) - - checkAnswer( - testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) - - checkAnswer( - testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) - - checkAnswer( - arrayData.orderBy('data.getItem(0).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) - - checkAnswer( - arrayData.orderBy('data.getItem(0).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) - - checkAnswer( - arrayData.orderBy('data.getItem(1).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) - - checkAnswer( - arrayData.orderBy('data.getItem(1).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) - } - - test("partition wide sorting") { - // 2 partitions totally, and - // Partition #1 with values: - // (1, 1) - // (1, 2) - // (2, 1) - // Partition #2 with values: - // (2, 2) - // (3, 1) - // (3, 2) - checkAnswer( - testData2.sortBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.sortBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.desc), - Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.asc), - Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) - } - - test("limit") { - checkAnswer( - testData.limit(10), - testData.take(10).toSeq) - - checkAnswer( - arrayData.limit(1), - arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) - - checkAnswer( - mapData.limit(1), - mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) - } - - test("SPARK-3395 limit distinct") { - val filtered = TestData.testData2 - .distinct() - .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) - .limit(1) - .registerTempTable("onerow") - checkAnswer( - sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: Nil) - } - - test("SPARK-3858 generator qualifiers are discarded") { - checkAnswer( - arrayData.as('ad) - .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) - .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Row(_))) - } - - test("average") { - checkAnswer( - testData2.aggregate(avg('a)), - Row(2.0)) - - checkAnswer( - testData2.aggregate(avg('a), sumDistinct('a)), // non-partial - Row(2.0, 6.0) :: Nil) - - checkAnswer( - decimalData.aggregate(avg('a)), - Row(new java.math.BigDecimal(2.0))) - checkAnswer( - decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) - - checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2))), - Row(new java.math.BigDecimal(2.0))) - checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial - Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) - } - - test("null average") { - checkAnswer( - testData3.aggregate(avg('b)), - Row(2.0)) - - checkAnswer( - testData3.aggregate(avg('b), countDistinct('b)), - Row(2.0, 1)) - - checkAnswer( - testData3.aggregate(avg('b), sumDistinct('b)), // non-partial - Row(2.0, 2.0)) - } - - test("zero average") { - checkAnswer( - emptyTableData.aggregate(avg('a)), - Row(null)) - - checkAnswer( - emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial - Row(null, null)) - } - - test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) - - checkAnswer( - testData2.aggregate(count('a), sumDistinct('a)), // non-partial - Row(6, 6.0)) - } - - test("null count") { - checkAnswer( - testData3.groupBy('a)('a, count('b)), - Seq(Row(1,0), Row(2, 1)) - ) - - checkAnswer( - testData3.groupBy('a)('a, count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) - ) - - checkAnswer( - testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), - Row(2, 1, 2, 2, 1) - ) - - checkAnswer( - testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial - Row(1, 1, 2) - ) - } - - test("zero count") { - assert(emptyTableData.count() === 0) - - checkAnswer( - emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial - Row(0, null)) - } - - test("zero sum") { - checkAnswer( - emptyTableData.aggregate(sum('a)), - Row(null)) - } - - test("zero sum distinct") { - checkAnswer( - emptyTableData.aggregate(sumDistinct('a)), - Row(null)) - } - - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) - } - - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) - } - - test("udf") { - val foo = (a: Int, b: String) => a.toString + b - - checkAnswer( - // SELECT *, foo(key, value) FROM testData - testData.select(Star(None), foo.call('key, 'value)).limit(3), - Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil - ) - } - - test("sqrt") { - checkAnswer( - testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Row(math.sqrt(n))) - ) - - checkAnswer( - testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Row(math.sqrt(n), n)) - ) - - checkAnswer( - testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("upper") { - checkAnswer( - lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase())) - ) - - checkAnswer( - testData.select(upper('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(upper(Literal(null))), - (1 to 100).map(n => Row(null)) - ) - } - - test("lower") { - checkAnswer( - upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase())) - ) - - checkAnswer( - testData.select(lower('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(lower(Literal(null))), - (1 to 100).map(n => Row(null)) - ) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index cd36da7751e8..037d392c1f92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,26 +20,28 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val rdd = sql(sqlString) - val physical = rdd.queryExecution.sparkPlan + val df = sql(sqlString) + val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j case j: HashOuterJoin => j @@ -49,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -58,8 +61,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - clearCache() + cacheManager.clearCache() + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -89,33 +93,56 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } } test("broadcasted hash join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } test("inner join where, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + upperCaseData.join(lowerCaseData).where('n === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -126,7 +153,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("inner join ON, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -136,10 +163,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Row(1,1,1,1) :: Row(1,1,1,2) :: Row(1,2,1,1) :: @@ -148,22 +175,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Nil) } test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) + bigDataX.join(bigDataY).where($"x.key" === $"y.key"), + testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { @@ -177,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left outer join") { checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -186,7 +212,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -195,7 +221,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -204,7 +230,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -240,7 +266,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("right outer join") { checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -248,7 +274,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -256,7 +282,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -264,7 +290,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -306,7 +332,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val right = UnresolvedRelation(Seq("right"), None) checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + left.join(right, $"left.N" === $"right.N", "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", 3, "C") :: @@ -315,7 +341,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -325,7 +351,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -385,7 +411,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") val tmp = conf.autoBroadcastJoinThreshold @@ -410,8 +436,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("left semi join") { - val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") - checkAnswer(rdd, + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(df, Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala new file mode 100644 index 000000000000..f9f41eb358bd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -0,0 +1,87 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfter { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + before { + df.registerTempTable("ListTablesSuiteTable") + } + + after { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + } + + test("get all tables") { + checkAnswer( + tables().filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("getting all Tables with a database name has no impact on returned table names") { + checkAnswer( + tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + checkAnswer( + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("query the returned DataFrame of tables") { + val expectedSchema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + Seq(tables(), sql("SHOW TABLes")).foreach { + case tableDF => + assert(expectedSchema === tableDF.schema) + + tableDF.registerTempTable("tables") + checkAnswer( + sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + Row(true, "ListTablesSuiteTable") + ) + checkAnswer( + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + dropTempTable("tables") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 42a21c148df5..59f9508444f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,101 +17,143 @@ package org.apache.spark.sql +import java.util.{Locale, TimeZone} + +import scala.collection.JavaConversions._ + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString + def checkExistence(df: DataFrame, exists: Boolean, keywords: String*) { + val outputs = df.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") + assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)") } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") + assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)") } } } /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } - if (!isSorted) converted.sortBy(_.toString) else converted + if (!isSorted) converted.sortBy(_.toString()) else converted } - val sparkAnswer = try rdd.collect().toSeq catch { + val sparkAnswer = try df.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: - |${rdd.queryExecution} + |${df.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: - |${rdd.logicalPlan} + |${df.logicalPlan} |== Analyzed Plan == - |${rdd.queryExecution.analyzed} + |${df.queryExecution.analyzed} |== Physical Plan == - |${rdd.queryExecution.executedPlan} + |${df.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } + return None } - /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(df, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468da..fb3ba4bc1b90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -27,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) @@ -50,4 +53,23 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } + + test("get values by field name on Row created via .toDF") { + val row = Seq((1, Seq(1))).toDF("a", "b").first() + assert(row.getAs[Int]("a") === 1) + assert(row.getAs[Seq[Int]]("b") === Seq(1)) + + intercept[IllegalArgumentException]{ + row.getAs[Int]("c") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d669..9e02e69fda3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql -import java.util.TimeZone - import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - -/* Implicits */ +import org.apache.spark.sql.execution.GeneratedAggregate +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData - var origZone: TimeZone = _ - override protected def beforeAll() { - origZone = TimeZone.getDefault - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + import org.apache.spark.sql.test.TestSQLContext.implicits._ + val sqlCtx = TestSQLContext + + test("self join with aliases") { + Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } - override protected def afterAll() { - TimeZone.setDefault(origZone) + test("self join with alias in agg") { + Seq(1,2,3) + .map(i => (i, i.toString)) + .toDF("int", "str") + .groupBy("str") + .agg($"str", count("str").as("strCount")) + .registerTempTable("df") + + checkAnswer( + sql( + """ + |SELECT x.str, SUM(x.strCount) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { @@ -78,14 +99,120 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() + // Prepare a table that we can group some rows. + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + test("Add Parser of SQL COALESCE()") { + checkAnswer( + sql("""SELECT COALESCE(1, 2)"""), + Row(1)) + checkAnswer( + sql("SELECT COALESCE(null, 1, 1.5)"), + Row(1.toDouble)) + checkAnswer( + sql("SELECT COALESCE(null, null, null)"), + Row(null)) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), @@ -129,26 +256,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( - "SELECT time FROM timestamps WHERE time=CAST('1970-01-01 00:00:00.001' AS TIMESTAMP)"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time='1970-01-01 00:00:00.001'"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.001'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE '1970-01-01 00:00:00.001'=time"), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001"))) + "SELECT time FROM timestamps WHERE '1969-12-31 16:00:00.001'=time"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) checkAnswer(sql( - """SELECT time FROM timestamps WHERE time<'1970-01-01 00:00:00.003' - AND time>'1970-01-01 00:00:00.001'"""), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002"))) + """SELECT time FROM timestamps WHERE time<'1969-12-31 16:00:00.003' + AND time>'1969-12-31 16:00:00.001'"""), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1970-01-01 00:00:00.001','1970-01-01 00:00:00.002')"), - Seq(Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.001")), - Row(java.sql.Timestamp.valueOf("1970-01-01 00:00:00.002")))) + """ + |SELECT time FROM timestamps + |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') + """.stripMargin), + Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) checkAnswer(sql( "SELECT time FROM timestamps WHERE time='123'"), @@ -184,6 +314,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1,3), Row(2,3), Row(3,3))) } + test("literal in agg grouping expressions") { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + } + test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), @@ -203,7 +342,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1")) } - def sortTest() = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) @@ -259,6 +398,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.EXTERNAL_SORT, before.toString) } + test("SPARK-6927 sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.EXTERNAL_SORT, "false") + setConf(SQLConf.CODEGEN_ENABLED, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + + test("SPARK-6927 external sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") + setConf(SQLConf.EXTERNAL_SORT, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"), @@ -273,9 +432,39 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("CTE feature") { + checkAnswer( + sql("with q1 as (select * from testData limit 10) select * from q1"), + testData.take(10).toSeq) + + checkAnswer( + sql(""" + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), + Row(5, "5") :: Row(4, "4") :: Nil) + + } + + test("Allow only a single WITH clause per query") { + intercept[RuntimeException] { + sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + } + } + + test("date row") { + checkAnswer(sql( + """select cast("2015-01-28" as date) from testData limit 1"""), + Row(java.sql.Date.valueOf("2015-01-28")) + ) + } + test("from follow multiple brackets") { checkAnswer(sql( - "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + """ + |select key from ((select * from testData limit 1) + | union all (select * from testData limit 1)) x limit 1 + """.stripMargin), Row(1) ) @@ -285,7 +474,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer(sql( - "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + """ + |select key from + | (select * from testData limit 1 union all select * from testData limit 1) x + | limit 1 + """.stripMargin), Row(1) ) } @@ -332,7 +525,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), + sql( + """ + |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 + """.stripMargin), Row(2, 1, 2, 2, 1)) } @@ -381,8 +577,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("big inner join, 4 matches per row") { - - checkAnswer( sql( """ @@ -396,7 +590,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData UNION ALL | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), - testData.flatMap( + testData.rdd.flatMap( row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } @@ -571,7 +765,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_))) // Column type mismatches where a coercion is not possible, in this case between integer // and array types, trigger a TreeNodeException. - intercept[TreeNodeException[_]] { + intercept[AnalysisException] { sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() } } @@ -651,8 +845,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") + val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: @@ -681,8 +875,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerTempTable("applySchema2") + val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: @@ -706,8 +900,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerTempTable("applySchema3") + val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + df3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), @@ -742,7 +936,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("metadata is propagated correctly") { - val person = sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -751,14 +945,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person, schemaWithMeta) - def validateMetadata(rdd: SchemaRDD): Unit = { + val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } personWithMeta.registerTempTable("personWithMeta") - validateMetadata(personWithMeta.select('name)) - validateMetadata(personWithMeta.select("name".attr)) - validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"id", $"name")) validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) @@ -766,7 +960,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3371 Renaming a function expression with group by gives error") { - udf.register("len", (s: String) => s.length) + TestSQLContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -786,13 +980,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - val logicalPlan = sql(query).queryExecution.logical - if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n" + logicalPlan) + val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) + assert(e.getMessage contains "group by") } else { // Should not throw sql(query).queryExecution.analyzed @@ -800,7 +990,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", isInvalidQuery = false) checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) @@ -951,9 +1141,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = sparkContext.parallelize( + Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1019,10 +1210,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) - rdd1.registerTempTable("nulldata1") + rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) - rdd2.registerTempTable("nulldata2") + rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) @@ -1031,7 +1222,34 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.registerTempTable("distinctData") + rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } + + test("SPARK-6145: ORDER BY test for nested fields") { + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + .registerTempTable("nestedOrder") + + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1)) + checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1)) + checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1)) + } + + test("SPARK-6145: special cases") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + } + + test("SPARK-6898: complete support for special chars in column names") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index a015884bae28..3fa00fd9d0cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -75,21 +75,25 @@ case class ComplexReflectData( dataField: Data) class ScalaReflectionRelationSuite extends FunSuite { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectData") + rdd.toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), + new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectNullData") + rdd.toDF().registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -97,15 +101,16 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectOptionalData") + rdd.toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerTempTable("reflectBinary") + rdd.toDF().registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) @@ -124,7 +129,7 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectComplexData") + rdd.toDF().registerTempTable("reflectComplexData") assert(sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala new file mode 100644 index 000000000000..6f6d3c9c243d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.TestSQLContext + +class SerializationSuite extends FunSuite { + + test("[SPARK-5235] SQLContext should be serializable") { + val sqlContext = new SQLContext(TestSQLContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 808ed5288cfb..225b51bd73d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,20 +20,19 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD + (1 to 100).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +43,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD + LargeAndSmallInts(3, 2) :: Nil).toDF() largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +54,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +66,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toSchemaRDD + DecimalData(3, 2) :: Nil).toDF() decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,17 +76,17 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD + BinaryData("123".getBytes(), 4) :: Nil).toDF() binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toSchemaRDD + TestData3(2, Some(2)) :: Nil).toDF() testData3.registerTempTable("testData3") - val emptyTableData = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation($"a".int, $"b".int) case class UpperCaseData(N: Int, L: String) val upperCaseData = @@ -97,7 +96,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toSchemaRDD + UpperCaseData(6, "F") :: Nil).toDF() upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +105,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toSchemaRDD + LowerCaseData(4, "d") :: Nil).toDF() lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -114,7 +113,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerTempTable("arrayData") + arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) val mapData = @@ -124,18 +123,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerTempTable("mapData") + mapData.toDF().registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerTempTable("repeatedData") + repeatedData.toDF().registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerTempTable("nullableRepeatedData") + nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -144,7 +143,7 @@ object TestData { NullInts(2) :: NullInts(3) :: NullInts(null) :: Nil - ) + ).toDF() nullInts.registerTempTable("nullInts") val allNulls = @@ -152,7 +151,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: - NullInts(null) :: Nil) + NullInts(null) :: Nil).toDF() allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) @@ -160,11 +159,15 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil) + NullStrings(3, null) :: Nil).toDF() nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") + TestSQLContext + .sparkContext + .parallelize(TableName("test") :: Nil) + .toDF() + .registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -177,29 +180,29 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerTempTable("timestamps") + timestamps.toDF().registerTempTable("timestamps") case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.registerTempTable("withEmptyParts") + withEmptyParts.toDF().registerTempTable("withEmptyParts") case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) val person = TestSQLContext.sparkContext.parallelize( Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil) + Person(1, "jim", 20) :: Nil).toDF() person.registerTempTable("person") val salary = TestSQLContext.sparkContext.parallelize( Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil) + Salary(1, 1000.0) :: Nil).toDF() salary.registerTempTable("salary") - case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) val complexData = TestSQLContext.sparkContext.parallelize( - ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toSchemaRDD + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) + :: Nil).toDF() complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0c9812003124..d615542ab50a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) @@ -28,25 +29,31 @@ class UDFSuite extends QueryTest { test("Simple UDF") { udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) - val result= + val result = sql("SELECT returnStruct('test', 'test2') as ret") - .select("ret.f1".attr).first().getString(0) - assert(result == "test") + .select($"ret.f1").head().getString(0) + assert(result === "test") + } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fbc8704f7837..2672e20deadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,11 +17,22 @@ package org.apache.spark.sql +import java.io.File + +import org.apache.spark.util.Utils + import scala.beans.{BeanInfo, BeanProperty} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark.rdd.RDD -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashSet @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { @@ -55,25 +66,27 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } } - override def userClass = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] + + private[spark] override def asNullable: MyDenseVectorUDT = this } class UserDefinedTypeSuite extends QueryTest { val points = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + val pointsRDD = sparkContext.parallelize(points).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) @@ -81,10 +94,51 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } + + + test("UDTs with Parquet") { + val tempDir = Utils.createTempDir() + tempDir.delete() + pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + } + + test("Repartition UDTs with Parquet") { + val tempDir = Utils.createTempDir() + tempDir.delete() + pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + } + + // Tests to make sure that all operators correctly convert types on the way out. + test("Local UDTs") { + val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[MyDenseVector](1) + df.take(1)(0).getAs[MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + } + + test("HyperLogLogUDT") { + val hyperLogLogUDT = HyperLogLogUDT + val hyperLogLog = new HyperLogLog(0.4) + (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + + val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) + assert(actual.cardinality() === hyperLogLog.cardinality()) + assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + } + + test("OpenHashSetUDT") { + val openHashSetUDT = new OpenHashSetUDT(IntegerType) + val set = new OpenHashSet[Int] + (1 to 10).foreach(i => set.add(i)) + + val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) + assert(actual.iterator.toSet === set.iterator.toSet) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index be2b34de077c..7cefcf44061c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -29,11 +29,12 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) - def testColumnStats[T <: NativeType, U <: ColumnStats]( + def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], initialStatistics: Row): Unit = { @@ -54,8 +55,8 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) - val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 87e608a8853d..1e105e259dce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp +import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.serializer.KryoRegistrator import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -33,8 +36,9 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, - STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, + FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12, + BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -56,22 +60,23 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) - checkActualSize(SHORT, Short.MaxValue, 2) - checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, new Date(0L), 8) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType[BooleanType.type]( @@ -93,13 +98,21 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType[DecimalType]( + FIXED_DECIMAL(15, 10), + (buffer: ByteBuffer, decimal: Decimal) => { + buffer.putLong(decimal.toUnscaledLong) + }, + (buffer: ByteBuffer) => { + Decimal(buffer.getLong(), 15, 10) + }) + testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { - - val bytes = string.getBytes("utf-8") + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -107,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( @@ -148,12 +161,47 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - def testNativeColumnType[T <: NativeType]( + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() + + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + } + } + + def testNativeColumnType[T <: AtomicType]( columnType: NativeColumnType[T], - putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType): Unit = { + putter: (ByteBuffer, T#InternalType) => Unit, + getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#JvmType](columnType, putter, getter) + testColumnType[T, T#InternalType](columnType, putter, getter) } def testColumnType[T <: DataType, JvmType]( @@ -206,4 +254,36 @@ class ColumnTypeSuite extends FunSuite with Logging { if (sb.nonEmpty) sb.setLength(sb.length - 1) sb.toString() } + + test("column type for decimal types with different precision") { + (1 to 18).foreach { i => + assertResult(FIXED_DECIMAL(i, 0)) { + ColumnType(DecimalType(i, 0)) + } + } + + assertResult(GENERIC) { + ColumnType(DecimalType(19, 0)) + } + } +} + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a,b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f941465fa3e3..75d993e563e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import scala.collection.immutable.HashSet import scala.util.Random -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} object ColumnarTestUtils { - def makeNullRow(length: Int) = { + def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row @@ -41,16 +41,17 @@ object ColumnarTestUtils { } (columnType match { - case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte - case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort - case INT => Random.nextInt() - case LONG => Random.nextLong() - case FLOAT => Random.nextFloat() - case DOUBLE => Random.nextDouble() - case STRING => Random.nextString(Random.nextInt(32)) - case BOOLEAN => Random.nextBoolean() - case BINARY => randomBytes(Random.nextInt(32)) - case DATE => new Date(Random.nextLong()) + case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte + case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort + case INT => Random.nextInt() + case LONG => Random.nextLong() + case FLOAT => Random.nextFloat() + case DOUBLE => Random.nextDouble() + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) + case STRING => UTF8String(Random.nextString(Random.nextInt(32))) + case BOOLEAN => Random.nextBoolean() + case BINARY => randomBytes(Random.nextInt(32)) + case DATE => Random.nextInt() case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) @@ -90,9 +91,9 @@ object ColumnarTestUtils { row } - def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int) = { + count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index e61f3c39631d..56591d9dba29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.columnar +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -36,10 +40,11 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + .toDF().registerTempTable("sizeTst") cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.logical.statistics.sizeInBytes > + table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > conf.autoBroadcastJoinThreshold) } @@ -114,4 +119,74 @@ class InMemoryColumnarQuerySuite extends QueryTest { complexData.count() complexData.unpersist() } + + test("decimal type") { + // Casting is required here because ScalaReflection can't capture decimal precision information. + val df = (1 to 10) + .map(i => Tuple1(Decimal(i, 15, 10))) + .toDF("dec") + .select($"dec" cast DecimalType(15, 10)) + + assert(df.schema.head.dataType === DecimalType(15, 10)) + + df.cache().registerTempTable("test_fixed_decimal") + checkAnswer( + sql("SELECT * FROM test_fixed_decimal"), + (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) + } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, (1 to i).toSeq)) + } + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. + checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index f95c895587f3..a0702144f942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -31,7 +31,8 @@ class TestNullableColumnAccessor[T <: DataType, JvmType]( with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) + : TestNullableColumnAccessor[T, JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) @@ -42,7 +43,8 @@ class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, + DATE, TIMESTAMP ).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 80bd5c94570c..3a5605d2335d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -27,7 +27,8 @@ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[T, JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder @@ -38,7 +39,8 @@ class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, + DATE, TIMESTAMP ).foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index c3a3f8ddc3eb..2a0b701cad7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize @@ -33,11 +34,13 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) - }, 5) + }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + // Enable in-memory table scan accumulators + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { @@ -104,14 +107,14 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val schemaRdd = sql(query) - val queryExecution = schemaRdd.queryExecution + val df = sql(query) + val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { - schemaRdd.collect().map(_(0)).toArray + df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index c82d9799359c..64b70552eb04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) - def testDictionaryEncoding[T <: NativeType]( + def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 88011631ee4e..bfd99f143bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -33,7 +33,7 @@ class IntegralDeltaSuite extends FunSuite { columnType: NativeColumnType[I], scheme: CompressionScheme) { - def skeleton(input: Seq[I#JvmType]) { + def skeleton(input: Seq[I#InternalType]) { // ------------- // Tests encoder // ------------- @@ -120,13 +120,13 @@ class IntegralDeltaSuite extends FunSuite { case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } test(s"$scheme: long random series") { // Have to workaround with `Any` since no `ClassTag[I#JvmType]` available here. val input = Array.fill[Any](10000)(makeRandomValue(columnType)) - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 08df1db37509..fde7a4595be0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new LongColumnStats, LONG) testRunLengthEncoding(new StringColumnStats, STRING) - def testRunLengthEncoding[T <: NativeType]( + def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 0b18b4119268..5268dfe0aa03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class TestCompressibleColumnBuilder[T <: NativeType]( +class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) @@ -32,10 +32,10 @@ class TestCompressibleColumnBuilder[T <: NativeType]( } object TestCompressibleColumnBuilder { - def apply[T <: NativeType]( + def apply[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T], - scheme: CompressionScheme) = { + scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) builder.initialize(0, "", useCompression = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 67007b8c093c..523be56df65b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -40,7 +43,7 @@ class PlannerSuite extends FunSuite { } test("count is partially aggregated") { - val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } @@ -48,14 +51,14 @@ class PlannerSuite extends FunSuite { } test("count distinct is partially aggregated") { - val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed + val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed + testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } @@ -69,7 +72,7 @@ class PlannerSuite extends FunSuite { val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - applySchema(rowRDD, schema).registerTempTable("testLimit") + createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -128,9 +131,9 @@ class PlannerSuite extends FunSuite { testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") - val a = testData.as('a) - val b = table("tiny").as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + val a = testData.as("a") + val b = table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 000000000000..27f063d73a9a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + checkSupported(DecimalType(10, 5), isSupported = true) + checkSupported(DecimalType.Unlimited, isSupported = true) + + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. + checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + var numShufflePartitions: Int = _ + var useSerializer2: Boolean = _ + + override def beforeAll(): Unit = { + numShufflePartitions = conf.numShufflePartitions + useSerializer2 = conf.useSqlSerializer2 + + sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i)) + } + + createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + dropTempTable("shuffle") + sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + sql(s"set spark.sql.useSerializer2=$useSerializer2") + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + table("shuffle").collect()) + } + + test("value schema is null") { + val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("no map output field") { + val df = sql(s"SELECT 1 + 1 FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + super.beforeAll() + // Sort merge will not be triggered. + sql("set spark.sql.shuffle.partitions = 200") + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + super.beforeAll() + // To trigger the sort merge. + sql("set spark.sql.shuffle.partitions = 201") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala deleted file mode 100644 index 272c0d4cb233..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ - -/* Implicit conversions */ -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns - * from the input data. These will be replaced during analysis with specific AttributeReferences - * and then bound to specific ordinals during query planning. While TGFs could also access specific - * columns using hand-coded ordinals, doing so violates data independence. - * - * Note: this is only a rough example of how TGFs can be expressed, the final version will likely - * involve a lot more sugar for cleaner use in Scala/Java/etc. - */ -case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { - def children = input - protected def makeOutput() = 'nameAndAge.string :: Nil - - val Seq(nameAttr, ageAttr) = input - - override def eval(input: Row): TraversableOnce[Row] = { - val name = nameAttr.eval(input) - val age = ageAttr.eval(input).asInstanceOf[Int] - - Iterator( - new GenericRow(Array[Any](s"$name is $age years old")), - new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) - } -} - -class TgfSuite extends QueryTest { - val inputData = - logical.LocalRelation('name.string, 'age.int).loadData( - ("michael", 29) :: Nil - ) - - test("simple tgf example") { - checkAnswer( - inputData.generate(ExampleTGF()), - Seq( - Row("michael is 29 years old"), - Row("Next year, michael will be 30 years old"))) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 87c28c334d22..358d8cf06e46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ class DebuggingSuite extends FunSuite { - test("SchemaRDD.debug()") { + test("DataFrame.debug()") { testData.debug() } - test("SchemaRDD.typeCheck()") { + test("DataFrame.typeCheck()") { testData.typeCheck() } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala new file mode 100644 index 000000000000..db096af4535a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.DriverManager +import java.util.{Calendar, GregorianCalendar, Properties} + +import org.apache.spark.sql.test._ +import org.h2.jdbc.JdbcSQLException +import org.scalatest.{FunSuite, BeforeAndAfter} +import TestSQLContext._ +import TestSQLContext.implicits._ + +class JDBCSuite extends FunSuite with BeforeAndAfter { + val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" + var conn: java.sql.Connection = null + + val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + + before { + Class.forName("org.h2.Driver") + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) + conn.prepareStatement("create schema test").executeUpdate() + conn.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() + conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() + conn.prepareStatement( + "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE foobar + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + sql( + s""" + |CREATE TEMPORARY TABLE parts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + + conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " + + "d SMALLINT, e BIGINT)").executeUpdate() + conn.prepareStatement("insert into test.inttypes values (1, false, 3, 4, 1234567890123)" + ).executeUpdate() + conn.prepareStatement("insert into test.inttypes values (null, null, null, null, null)" + ).executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE inttypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), " + + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate() + val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") + stmt.setBytes(1, testBytes) + stmt.setString(2, "Sensitive") + stmt.setString(3, "Insensitive") + stmt.setString(4, "Twenty-byte CHAR") + stmt.setBytes(5, testBytes) + stmt.setString(6, "I am a clob!") + stmt.executeUpdate() + sql( + s""" + |CREATE TEMPORARY TABLE strtypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" + ).executeUpdate() + conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + + "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE timetypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + ).executeUpdate() + conn.prepareStatement("insert into test.flttypes values (" + + "1.0000000000000002220446049250313080847263336181640625, " + + "1.00000011920928955078125, " + + "123456789012345.543215432154321)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE flttypes + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. + } + + after { + conn.close() + } + + test("SELECT *") { + assert(sql("SELECT * FROM foobar").collect().size === 3) + } + + test("SELECT * WHERE (simple predicates)") { + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + } + + test("SELECT * WHERE (quoted strings)") { + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) + } + + test("SELECT first field") { + val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) + assert(names.size === 3) + assert(names(0).equals("fred")) + assert(names(1).equals("joe 'foo' \"bar\"")) + assert(names(2).equals("mary")) + } + + test("SELECT second field") { + val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + + test("SELECT * partitioned") { + assert(sql("SELECT * FROM parts").collect().size == 3) + } + + test("SELECT WHERE (simple predicates) partitioned") { + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) + } + + test("SELECT second field partitioned") { + val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + + test("Basic API") { + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) + } + + test("Partitioning via JDBCPartitioningInfo API") { + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) + .collect.size === 3) + } + + test("Partitioning via list-of-where-clauses API") { + val parts = Array[String]("THEID < 2", "THEID >= 2") + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) + } + + test("H2 integral types") { + val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() + assert(rows.size === 1) + assert(rows(0).getInt(0) === 1) + assert(rows(0).getBoolean(1) === false) + assert(rows(0).getInt(2) === 3) + assert(rows(0).getInt(3) === 4) + assert(rows(0).getLong(4) === 1234567890123L) + } + + test("H2 null entries") { + val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() + assert(rows.size === 1) + assert(rows(0).isNullAt(0)) + assert(rows(0).isNullAt(1)) + assert(rows(0).isNullAt(2)) + assert(rows(0).isNullAt(3)) + assert(rows(0).isNullAt(4)) + } + + test("H2 string types") { + val rows = sql("SELECT * FROM strtypes").collect() + assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes)) + assert(rows(0).getString(1).equals("Sensitive")) + assert(rows(0).getString(2).equals("Insensitive")) + assert(rows(0).getString(3).equals("Twenty-byte CHAR")) + assert(rows(0).getAs[Array[Byte]](4).sameElements(testBytes)) + assert(rows(0).getString(5).equals("I am a clob!")) + } + + test("H2 time types") { + val rows = sql("SELECT * FROM timetypes").collect() + val cal = new GregorianCalendar(java.util.Locale.ROOT) + cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) + assert(cal.get(Calendar.HOUR_OF_DAY) === 12) + assert(cal.get(Calendar.MINUTE) === 34) + assert(cal.get(Calendar.SECOND) === 56) + cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) + assert(cal.get(Calendar.YEAR) === 1996) + assert(cal.get(Calendar.MONTH) === 0) + assert(cal.get(Calendar.DAY_OF_MONTH) === 1) + cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) + assert(cal.get(Calendar.YEAR) === 2002) + assert(cal.get(Calendar.MONTH) === 1) + assert(cal.get(Calendar.DAY_OF_MONTH) === 20) + assert(cal.get(Calendar.HOUR) === 11) + assert(cal.get(Calendar.MINUTE) === 22) + assert(cal.get(Calendar.SECOND) === 33) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) + } + + test("test DATE types") { + val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() + val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + + test("H2 floating-point types") { + val rows = sql("SELECT * FROM flttypes").collect() + assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. + assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getAs[BigDecimal](2) + .equals(new BigDecimal("123456789012345.54321543215432100000"))) + } + + test("SQL query as table name") { + sql( + s""" + |CREATE TEMPORARY TABLE hack + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + val rows = sql("SELECT * FROM hack").collect() + assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. + // For some reason, H2 computes this square incorrectly... + assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) + } + + test("Pass extra properties via OPTIONS") { + // We set rowId to false during setup, which means that _ROWID_ column should be absent from + // all tables. If rowId is true (default), the query below doesn't throw an exception. + intercept[JdbcSQLException] { + sql( + s""" + |CREATE TEMPORARY TABLE abc + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala new file mode 100644 index 000000000000..ee5c7620d1a2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.DriverManager + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test._ +import org.apache.spark.sql.types._ + +class JDBCWriteSuite extends FunSuite with BeforeAndAfter { + val url = "jdbc:h2:mem:testdb2" + var conn: java.sql.Connection = null + + before { + Class.forName("org.h2.Driver") + conn = DriverManager.getConnection(url) + conn.prepareStatement("create schema test").executeUpdate() + } + + after { + conn.close() + } + + val sc = TestSQLContext.sparkContext + + val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + val arr1x2 = Array[Row](Row.apply("fred", 3)) + val schema2 = StructType( + StructField("name", StringType) :: + StructField("id", IntegerType) :: Nil) + + val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + val schema3 = StructType( + StructField("name", StringType) :: + StructField("id", IntegerType) :: + StructField("seq", IntegerType) :: Nil) + + test("Basic CREATE") { + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + + df.createJDBCTable(url, "TEST.BASICCREATETEST", false) + assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length) + } + + test("CREATE with overwrite") { + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + + df.createJDBCTable(url, "TEST.DROPTEST", false) + assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) + assert(3 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) + + df2.createJDBCTable(url, "TEST.DROPTEST", true) + assert(1 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) + } + + test("CREATE then INSERT to append") { + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + + df.createJDBCTable(url, "TEST.APPENDTEST", false) + df2.insertIntoJDBC(url, "TEST.APPENDTEST", false) + assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length) + } + + test("CREATE then INSERT to truncate") { + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + + df.createJDBCTable(url, "TEST.TRUNCATETEST", false) + df2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) + assert(1 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").count) + assert(2 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").collect()(0).length) + } + + test("Incompatible INSERT to append") { + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + + df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + intercept[org.apache.spark.SparkException] { + df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 94d14acccbb1..fd0e2746dc04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -19,16 +19,22 @@ package org.apache.spark.sql.json import java.sql.{Date, Timestamp} +import org.scalactic.Tolerance._ + import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ + TestJsonData test("Type promotion") { @@ -65,14 +71,15 @@ class JsonSuite extends QueryTest { checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" - checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) + checkTypePromotion( + DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(new Date(3601000), enforceCorrectType(ISO8601Time1, DateType)) + checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(new Date(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -193,7 +200,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonSchemaRDD = jsonRDD(jsonNullStruct) + val jsonDF = jsonRDD(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -202,8 +209,8 @@ class JsonSuite extends QueryTest { StructField("ip", StringType, true) :: StructField("nullstr", StringType, true):: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + assert(expectedSchema === jsonDF.schema) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select nullstr, headers.Host from jsonTable"), @@ -212,20 +219,20 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + val jsonDF = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -240,33 +247,33 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) + val jsonDF = jsonRDD(complexFieldAndType1) val expectedSchema = StructType( - StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: - StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: - StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: - StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: - StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: - StructField("arrayOfLong", ArrayType(LongType, false), true) :: + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: + StructField("arrayOfInteger", ArrayType(LongType, true), true) :: + StructField("arrayOfLong", ArrayType(LongType, true), true) :: StructField("arrayOfNull", ArrayType(StringType, true), true) :: - StructField("arrayOfString", ArrayType(StringType, false), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: StructField("field2", StringType, true) :: - StructField("field3", StringType, true) :: Nil), false), true) :: + StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( - StructField("field1", ArrayType(IntegerType, false), true) :: - StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) + StructField("field1", ArrayType(LongType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -338,26 +345,24 @@ class JsonSuite extends QueryTest { ) } - ignore("Complex field and type inferring (Ignored)") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) - jsonSchemaRDD.registerTempTable("jsonTable") + test("GetField operation on complex data type") { + val jsonDF = jsonRDD(complexFieldAndType1) + jsonDF.registerTempTable("jsonTable") - // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) - // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - Row(Seq(true, false), Seq("str1", null)) + Row(Seq(true, false, null), Seq("str1", null, null)) ) } test("Type conflict in primitive field values") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -367,16 +372,18 @@ class JsonSuite extends QueryTest { StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("false", 21474836470L, + new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, + new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -399,7 +406,8 @@ class JsonSuite extends QueryTest { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: + Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType @@ -428,8 +436,8 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -462,9 +470,9 @@ class JsonSuite extends QueryTest { // We should directly cast num_str to DecimalType and also need to do the right type promotion // in the Project. checkAnswer( - jsonSchemaRDD. + jsonDF. where('num_str > BigDecimal("92233720368547758060")). - select('num_str + 1.2 as Symbol("num")), + select(('num_str + 1.2).as("num")), Row(new java.math.BigDecimal("92233720368547758061.2")) ) @@ -481,19 +489,19 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( - StructField("array", ArrayType(IntegerType, false), true) :: + StructField("array", ArrayType(LongType, true), true) :: StructField("num_struct", StringType, true) :: StructField("str_array", StringType, true) :: StructField("struct", StructType( StructField("field", StringType, true) :: Nil), true) :: StructField("struct_array", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -505,17 +513,17 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) + val jsonDF = jsonRDD(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil), false), true) :: - StructField("array3", ArrayType(StringType, false), true) :: Nil) + StructField("field", LongType, true) :: Nil), true), true) :: + StructField("array3", ArrayType(StringType, true), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -533,39 +541,67 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonSchemaRDD = jsonRDD(missingFields) + val jsonDF = jsonRDD(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: StructField("b", LongType, true) :: - StructField("c", ArrayType(IntegerType, false), true) :: + StructField("c", ArrayType(LongType, true), true) :: StructField("d", StructType( StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") + } + + test("jsonFile should be based on JSONRelation") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = jsonFile(path, 0.49) + + val analyzed = jsonDF.queryExecution.analyzed + assert( + analyzed.isInstanceOf[LogicalRelation], + "The DataFrame returned by jsonFile should be based on JSONRelation.") + val relation = analyzed.asInstanceOf[LogicalRelation].relation + assert( + relation.isInstanceOf[JSONRelation], + "The DataFrame returned by jsonFile should be based on JSONRelation.") + assert(relation.asInstanceOf[JSONRelation].path === path) + assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + + val schema = StructType(StructField("a", LongType, true) :: Nil) + val logicalRelation = + jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] + val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] + assert(relationWithSchema.path === path) + assert(relationWithSchema.schema === schema) + assert(relationWithSchema.samplingRatio > 0.99) } test("Loading a JSON dataset from a text file") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonSchemaRDD = jsonFile(path) + val jsonDF = jsonFile(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -580,8 +616,9 @@ class JsonSuite extends QueryTest { } test("Loading a JSON dataset from a text file with SQL") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) sql( @@ -606,8 +643,9 @@ class JsonSuite extends QueryTest { } test("Applying schemas") { - val file = getTempFilePath("json") - val path = file.toString + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( @@ -619,11 +657,11 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonSchemaRDD1 = jsonFile(path, schema) + val jsonDF1 = jsonFile(path, schema) - assert(schema === jsonSchemaRDD1.schema) + assert(schema === jsonDF1.schema) - jsonSchemaRDD1.registerTempTable("jsonTable1") + jsonDF1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -636,11 +674,11 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) - assert(schema === jsonSchemaRDD2.schema) + assert(schema === jsonDF2.schema) - jsonSchemaRDD2.registerTempTable("jsonTable2") + jsonDF2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), @@ -654,9 +692,65 @@ class JsonSuite extends QueryTest { ) } + test("Applying schemas with MapType") { + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + + jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") + + checkAnswer( + sql("select map from jsonWithSimpleMap"), + Row(Map("a" -> 1)) :: + Row(Map("b" -> 2)) :: + Row(Map("c" -> 3)) :: + Row(Map("c" -> 1, "d" -> 4)) :: + Row(Map("e" -> null)) :: Nil + ) + + checkAnswer( + sql("select map['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + + val innerStruct = StructType( + StructField("field1", ArrayType(IntegerType, true), true) :: + StructField("field2", IntegerType, true) :: Nil) + val schemaWithComplexMap = StructType( + StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) + + val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + + jsonWithComplexMap.registerTempTable("jsonWithComplexMap") + + checkAnswer( + sql("select map from jsonWithComplexMap"), + Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: + Row(Map("b" -> Row(null, 2))) :: + Row(Map("c" -> Row(Seq(), 4))) :: + Row(Map("c" -> Row(null, 3), "d" -> Row(Seq(null), null))) :: + Row(Map("e" -> null)) :: + Row(Map("f" -> Row(null, null))) :: Nil + ) + + checkAnswer( + sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } + test("SPARK-2096 Correctly parse dot notations") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), @@ -673,8 +767,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -696,8 +790,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonSchemaRDD = jsonRDD(jsonArray) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(jsonArray) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -717,8 +811,8 @@ class JsonSuite extends QueryTest { val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonSchemaRDD = jsonRDD(corruptRecords) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(corruptRecords) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: @@ -726,7 +820,7 @@ class JsonSuite extends QueryTest { StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) // In HiveContext, backticks should be used to access columns starting with a underscore. checkAnswer( @@ -771,22 +865,22 @@ class JsonSuite extends QueryTest { } test("SPARK-4068: nulls in arrays") { - val jsonSchemaRDD = jsonRDD(nullsInArrays) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(nullsInArrays) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("field1", - ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) :: + ArrayType(ArrayType(ArrayType(ArrayType(StringType, true), true), true), true), true) :: StructField("field2", ArrayType(ArrayType( - StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) :: + StructType(StructField("Test", LongType, true) :: Nil), true), true), true) :: StructField("field3", ArrayType(ArrayType( - StructType(StructField("Test", StringType, true) :: Nil), true), false), true) :: + StructType(StructField("Test", StringType, true) :: Nil), true), true), true) :: StructField("field4", - ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) + ArrayType(ArrayType(ArrayType(LongType, true), true), true), true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) checkAnswer( sql( @@ -801,8 +895,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 SchemaRDD to JSON") - { + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: StructField("f2", StringType, false) :: @@ -818,12 +911,14 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") - val schemaRDD2 = schemaRDD1.toSchemaRDD - val result = schemaRDD2.toJSON.collect() - assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") - assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + val df1 = createDataFrame(rowRDD1, schema1) + df1.registerTempTable("applySchema1") + val df2 = df1.toDF + val result = df2.toJSON.collect() + // scalastyle:off + assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") + assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + // scalastyle:on val schema2 = StructType( StructField("f1", StructType( @@ -839,16 +934,16 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD3 = applySchema(rowRDD2, schema2) - schemaRDD3.registerTempTable("applySchema2") - val schemaRDD4 = schemaRDD3.toSchemaRDD - val result2 = schemaRDD4.toJSON.collect() + val df3 = createDataFrame(rowRDD2, schema2) + df3.registerTempTable("applySchema2") + val df4 = df3.toDF + val result2 = df4.toJSON.collect() - assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") - assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") + assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") + assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonSchemaRDD.toJSON) + val jsonDF = jsonRDD(primitiveFieldAndType) + val primTable = jsonRDD(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -860,8 +955,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonSchemaRDD.toJSON) + val complexJsonDF = jsonRDD(complexFieldAndType1) + val compTable = jsonRDD(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -877,7 +972,8 @@ class JsonSuite extends QueryTest { // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) @@ -917,9 +1013,41 @@ class JsonSuite extends QueryTest { // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + "from complexTable"), Row(5, null) ) + } + test("JSONRelation equality test") { + val relation1 = + JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null) + val logicalRelation1 = LogicalRelation(relation1) + val relation2 = + JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))( + org.apache.spark.sql.test.TestSQLContext) + val logicalRelation2 = LogicalRelation(relation2) + val relation3 = + JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null) + val logicalRelation3 = LogicalRelation(relation3) + + assert(relation1 === relation2) + assert(logicalRelation1.sameResult(logicalRelation2), + s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") + + assert(relation1 !== relation3) + assert(!logicalRelation1.sameResult(logicalRelation3), + s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.") + + assert(relation2 !== relation3) + assert(!logicalRelation2.sameResult(logicalRelation3), + s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") } + + test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { + // This is really a test that it doesn't throw an exception + val emptySchema = JsonRDD.inferSchema(empty, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 3370b3c98b4b..47a97a49daab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -146,6 +146,23 @@ object TestJsonData { ]] }""" :: Nil) + val mapType1 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": 1}}""" :: + """{"map": {"b": 2}}""" :: + """{"map": {"c": 3}}""" :: + """{"map": {"c": 1, "d": 4}}""" :: + """{"map": {"e": null}}""" :: Nil) + + val mapType2 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: + """{"map": {"b": {"field2": 2}}}""" :: + """{"map": {"c": {"field1": [], "field2": 4}}}""" :: + """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: + """{"map": {"e": null}}""" :: + """{"map": {"f": {"field1": null}}}""" :: Nil) + val nullsInArrays = TestSQLContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: @@ -168,4 +185,7 @@ object TestJsonData { """{"a":{, b:3}""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + + val empty = + TestSQLContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 1e7d3e06fc19..10d0ede4dc0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -37,23 +41,33 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { +class ParquetFilterSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: SchemaRDD, + df: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Seq[Row]) => Unit, + checker: (DataFrame, Seq[Row]) => Unit, expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output: _*).where(predicate) + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) + val maybeAnalyzedPredicate = { + val forParquetTableScan = query.queryExecution.executedPlan.collect { + case plan: ParquetTableScan => plan.columnPruningPred + }.flatten.reduceOption(_ && _) + + val forParquetDataSource = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters + }.flatten.reduceOption(_ && _) + + forParquetTableScan.orElse(forParquetDataSource) + } assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -71,168 +85,196 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(df, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: SchemaRDD): Unit = { - checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) + (implicit df: DataFrame): Unit = { + checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit df: DataFrame): Unit = { + def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + } + } + + checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected) + } + + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit df: DataFrame): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPredicate('_1 === true, classOf[Eq [_]], true) + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } + test("filter pushdown - short") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") - checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) - } - } - - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { - rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted - } + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } - - checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) - } - - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: SchemaRDD): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } test("filter pushdown - binary") { @@ -240,25 +282,26 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def b: Array[Byte] = int.toString.getBytes("UTF-8") } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) @@ -267,3 +310,66 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } } + +class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } +} + +class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("SPARK-6742: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + val df = DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( + path, + Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, + Seq(AttributeReference("part", IntegerType, false)()) )) + + checkAnswer( + df.filter("a = 1 or part = 1"), + (1 to 3).map(i => Row(1, i, i.toString))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a57e4e85a35e..97c0f439acf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -21,23 +21,25 @@ import scala.collection.JavaConversions._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.scalatest.BeforeAndAfterAll import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.{ParquetFileWriter, ParquetWriter} +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -62,14 +64,16 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { +class ParquetIOSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext + import sqlContext.implicits.localSeqToDataFrameHolder + /** * Writes `data` to a Parquet file, reads it back and check file contents. */ protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = { - withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) + withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } test("basic data types (without binary)") { @@ -81,9 +85,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("raw binary") { val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } } } @@ -97,11 +101,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): SchemaRDD = + + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select('_1 cast decimal) + .toDF() + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => @@ -128,6 +135,21 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } + test("date type") { + def makeDateRDD(): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(DateUtils.toJavaDate(i))) + .toDF() + .select($"_1") + + withTempPath { dir => + val data = makeDateRDD() + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } + } + test("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) @@ -140,9 +162,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => + checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) }) } @@ -150,9 +172,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => + withParquetDataFrame(data) { df => // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => + checkAnswer(df, data.map { case Tuple1(struct) => Row(Row(struct.productIterator.toSeq: _*)) }) } @@ -160,8 +182,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("nested map with struct as value type") { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => + withParquetDataFrame(data) { df => + checkAnswer(df, data.map { case Tuple1(m) => Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) }) } @@ -175,8 +197,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { null.asInstanceOf[java.lang.Float], null.asInstanceOf[java.lang.Double]) - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() + withParquetDataFrame(allNulls :: Nil) { df => + val rows = df.collect() assert(rows.size === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } @@ -188,15 +210,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest { None.asInstanceOf[Option[Long]], None.asInstanceOf[Option[String]]) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() + withParquetDataFrame(allNones :: Nil) { df => + val rows = df.collect() assert(rows.size === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) } } test("compression codec") { - def compressionCodecFor(path: String) = { + def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter .readMetaData(new Path(path), Some(configuration)) .getBlocks @@ -284,4 +306,115 @@ class ParquetIOSuite extends QueryTest with ParquetTest { expectedSchema.checkContains(actualSchema) } } + + test("save - overwrite") { + withParquetFile((1 to 10).map(i => (i, i.toString))) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) + checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + } + } + + test("save - ignore") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) + checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + } + } + + test("save - throw") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + val errorMessage = intercept[Throwable] { + newData.toDF().save( + "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + }.getMessage + assert(errorMessage.contains("already exists")) + } + } + + test("save - append") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) + checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) + } + } + + test("SPARK-6315 regression test") { + // Spark 1.1 and prior versions write Spark schema as case class string into Parquet metadata. + // This has been deprecated by JSON format since 1.2. Notice that, 1.3 further refactored data + // types API, and made StructType.fields an array. This makes the result of StructType.toString + // different from prior versions: there's no "Seq" wrapping the fields part in the string now. + val sparkSchema = + "StructType(Seq(StructField(a,BooleanType,false),StructField(b,IntegerType,false)))" + + // The Parquet schema is intentionally made different from the Spark schema. Because the new + // Parquet data source simply falls back to the Parquet schema once it fails to parse the Spark + // schema. By making these two different, we are able to assert the old style case class string + // is parsed successfully. + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c; + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + + ParquetFileWriter.writeMetadataFile( + sparkContext.hadoopConfiguration, + path, + new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + + assertResult(parquetFile(path.toString).schema) { + StructType( + StructField("a", BooleanType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: + Nil) + } + } + } +} + +class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("SPARK-6330 regression test") { + // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: + // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// + intercept[java.io.FileNotFoundException] { + sqlContext.parquetFile("file:///nonexistent") + } + val errorMessage = intercept[Throwable] { + sqlContext.parquetFile("hdfs://nonexistent") + }.toString + assert(errorMessage.contains("UnknownHostException")) + } +} + +class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala new file mode 100644 index 000000000000..b7561ce7298c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.parquet + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.parquet.ParquetRelation2._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{QueryTest, Row, SQLContext} + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { + override val sqlContext: SQLContext = TestSQLContext + + import sqlContext._ + import sqlContext.implicits._ + + val defaultPartitionName = "__NULL__" + + test("column type inference") { + def check(raw: String, literal: Literal): Unit = { + assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + } + + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, FloatType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) + } + + test("parse partition") { + def check(path: String, expected: PartitionValues): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName)) + } + + def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { + val message = intercept[T] { + parsePartition(new Path(path), defaultPartitionName) + }.getMessage + + assert(message.contains(expected)) + } + + check( + "file:///", + PartitionValues( + ArrayBuffer.empty[String], + ArrayBuffer.empty[Literal])) + + check( + "file://path/a=10", + PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal.create(10, IntegerType)))) + + check( + "file://path/a=10/b=hello/c=1.5", + PartitionValues( + ArrayBuffer("a", "b", "c"), + ArrayBuffer( + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, FloatType)))) + + check( + "file://path/a=10/b_hello/c=1.5", + PartitionValues( + ArrayBuffer("c"), + ArrayBuffer(Literal.create(1.5, FloatType)))) + + checkThrows[AssertionError]("file://path/=10", "Empty partition column name") + checkThrows[AssertionError]("file://path/a=", "Empty partition column value") + } + + test("parse partitions") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } + + test("read partitioned table - merging compatible schemas") { + withTempDir { base => + makeParquetFile( + (1 to 10).map(i => Tuple1(i)).toDF("intField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 1)) + + makeParquetFile( + (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), + makePartitionDir(base, defaultPartitionName, "pi" -> 2)) + + load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 1263ff818ea1..b98ba09ccfc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.QueryTest +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -25,21 +27,34 @@ import org.apache.spark.sql.test.TestSQLContext._ /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { +class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - test("simple projection") { + test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_))) } } test("appending") { val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { - sql("INSERT INTO t SELECT * FROM t") + sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -53,8 +68,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -63,7 +78,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -72,7 +87,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) @@ -82,7 +97,42 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } +} + +class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 64274950b868..c964b6d98455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ class ParquetSchemaSuite extends FunSuite with ParquetTest { val sqlContext = TestSQLContext @@ -33,9 +34,10 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String): Unit = { + testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes(ScalaReflection.attributesFor[T]) + val actual = ParquetTypesConverter.convertFromAttributes( + ScalaReflection.attributesFor[T], isThriftDerived) val expected = MessageTypeParser.parseMessageType(messageType) actual.checkContains(expected) expected.checkContains(actual) @@ -55,7 +57,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Byte, Short, Int, Long)]( + testSchema[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -63,6 +65,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | required int32 _2 (INT_16); | required int32 _3 (INT_32); | required int64 _4 (INT_64); + | optional int32 _5 (DATE); |} """.stripMargin) @@ -146,6 +149,29 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, isThriftDerived = true) + test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -154,10 +180,12 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { val caseClassString = "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + // scalastyle:off val jsonString = """ |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} """.stripMargin + // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) val fromJson = ParquetTypesConverter.convertFromString(jsonString) @@ -168,4 +196,86 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { assert(a.nullable === b.nullable) } } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala new file mode 100644 index 000000000000..20a23b3bd6aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.{IOException, File} + +import org.apache.spark.sql.AnalysisException +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.util.Utils + +class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var path: File = null + + override def beforeAll(): Unit = { + path = Utils.createTempDir() + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + } + + override def afterAll(): Unit = { + dropTempTable("jt") + } + + after { + Utils.deleteRecursively(path) + } + + test("CREATE TEMPORARY TABLE AS SELECT") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) + + dropTempTable("jsonTable") + } + + test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { + val childPath = new File(path.toString, "child") + path.mkdir() + childPath.createNewFile() + path.setWritable(false) + + val e = intercept[IOException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + sql("SELECT a, b FROM jsonTable").collect() + } + assert(e.getMessage().contains("Unable to clear output directory")) + + path.setWritable(true) + } + + test("create a table, drop it and create another one with the same name") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt").collect()) + + val message = intercept[DDLException]{ + sql( + s""" + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), + "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") + + // Overwrite the temporary table. + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a * 4 FROM jt").collect()) + + dropTempTable("jsonTable") + // Explicitly delete the data. + if (path.exists()) Utils.deleteRecursively(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT b FROM jt").collect()) + + dropTempTable("jsonTable") + } + + test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { + val message = intercept[DDLException]{ + sql( + s""" + |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT b FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), + "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") + } + + test("a CTAS statement with column definitions is not allowed") { + intercept[DDLException]{ + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + } + } + + test("it is not allowed to write to a table while querying it.") { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot overwrite table "), + "Writing to a table while querying it should not be allowed.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala new file mode 100644 index 000000000000..ca25751b9583 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -0,0 +1,102 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) + } +} + +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan { + + override def schema: StructType = + StructType(Seq( + StructField("intType", IntegerType, nullable = false, + new MetadataBuilder().putString("comment", s"test comment $table").build()), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1",StringType) :: + (StructField("f2",IntegerType)) :: Nil + ) + ) + )) + + + override def buildScan(): RDD[Row] = { + sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + } +} + +class DDLTestSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + } + + sqlTest( + "describe ddlPeople", + Seq( + Row("intType", "int", "test comment test1"), + Row("stringType", "string", ""), + Row("dateType", "date", ""), + Row("timestampType", "timestamp", ""), + Row("doubleType", "double", ""), + Row("bigintType", "bigint", ""), + Row("tinyintType", "tinyint", ""), + Row("decimalType", "decimal(10,0)", ""), + Row("fixedDecimalType", "decimal(5,1)", ""), + Row("binaryType", "binary", ""), + Row("booleanType", "boolean", ""), + Row("smallIntType", "smallint", ""), + Row("floatType", "float", ""), + Row("mapType", "map", ""), + Row("arrayType", "array", ""), + Row("structType", "struct", "") + )) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 9626252e742e..33c67355967d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -28,7 +28,15 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { @transient override protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) { + override val extendedResolutionRules = + PreInsertCastAndRename :: + Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 390538d35a34..cb5e5147ff18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -32,31 +33,57 @@ class FilteredScanSource extends RelationProvider { } case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedFilteredScan { + extends BaseRelation + with PrunedFilteredScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: Nil) + StructField("b", IntegerType, nullable = false) :: + StructField("c", StringType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters - val filterFunctions = filters.collect { + // Predicate test on integer column + def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) + case IsNull("a") => (a: Int) => false // Int can't be null + case IsNotNull("a") => (a: Int) => true + case Not(pred) => (a: Int) => !translateFilterOnA(pred)(a) + case And(left, right) => (a: Int) => + translateFilterOnA(left)(a) && translateFilterOnA(right)(a) + case Or(left, right) => (a: Int) => + translateFilterOnA(left)(a) || translateFilterOnA(right)(a) + case _ => (a: Int) => true } - def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + // Predicate test on string column + def translateFilterOnC(filter: Filter): String => Boolean = filter match { + case StringStartsWith("c", v) => _.startsWith(v) + case StringEndsWith("c", v) => _.endsWith(v) + case StringContains("c", v) => _.contains(v) + case _ => (c: String) => true + } + + def eval(a: Int) = { + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 + !filters.map(translateFilterOnA(_)(a)).contains(false) && + !filters.map(translateFilterOnC(_)(c)).contains(false) + } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) @@ -86,7 +113,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -121,20 +149,52 @@ class FilteredScanSuite extends DataSourceTest { (2 to 10 by 2).map(i => Row(i, i)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2))) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NULL", + Seq.empty[Row]) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE a IS NOT NULL", + (1 to 10).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE A = 1", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 5 AND a > 1", + (2 to 4).map(i => Row(i, i * 2)).toSeq) sqlTest( - "SELECT * FROM oneToTenFiltered WHERE b = 2", - Seq(1).map(i => Row(i, i * 2)).toSeq) + "SELECT a, b FROM oneToTenFiltered WHERE a < 3 OR a > 8", + Seq(1, 2, 9, 10).map(i => Row(i, i * 2))) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered WHERE NOT (a < 6)", + (6 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) + + sqlTest( + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) @@ -162,6 +222,19 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala new file mode 100644 index 000000000000..80efe9728fbc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.util.Utils + +class InsertSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var path: File = null + + override def beforeAll: Unit = { + path = Utils.createTempDir() + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + sql( + s""" + |CREATE TEMPORARY TABLE jsonTable (a int, b string) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${path.toString}' + |) + """.stripMargin) + } + + override def afterAll: Unit = { + dropTempTable("jsonTable") + dropTempTable("jt") + Utils.deleteRecursively(path) + } + + test("Simple INSERT OVERWRITE a JSONRelation") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("PreInsert casting and renaming") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 2, s"${i * 4}")) + ) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 4, s"${i * 6}")) + ) + } + + test("SELECT clause generating a different number of columns is not allowed.") { + val message = intercept[RuntimeException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("generates the same number of columns as its schema"), + "SELECT clause generating a different number of columns should not be not allowed." + ) + } + + test("INSERT OVERWRITE a JSONRelation multiple times") { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + + test("INSERT INTO not supported for JSONRelation for now") { + intercept[RuntimeException]{ + sql( + s""" + |INSERT INTO TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + } + } + + test("it is not allowed to write to a table while querying it.") { + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable + """.stripMargin) + }.getMessage + assert( + message.contains("Cannot insert overwrite into table that is also being read from."), + "INSERT OVERWRITE to a table while querying it should not be allowed.") + } + + test("Caching") { + // Cached Query Execution + cacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable")) + checkAnswer( + sql("SELECT * FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i"))) + + assertCached(sql("SELECT a FROM jsonTable")) + checkAnswer( + sql("SELECT a FROM jsonTable"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT a FROM jsonTable WHERE a < 5")) + checkAnswer( + sql("SELECT a FROM jsonTable WHERE a < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT a * 2 FROM jsonTable")) + checkAnswer( + sql("SELECT a * 2 FROM jsonTable"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer( + sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Insert overwrite and keep the same schema. + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt + """.stripMargin) + // jsonTable should be recached. + assertCached(sql("SELECT * FROM jsonTable")) + // The cached data is the new data. + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a * 2, b FROM jt").collect()) + + // Verify uncaching + uncacheTable("jsonTable") + assertCached(sql("SELECT * FROM jsonTable"), 0) + } + + test("it's not allowed to insert into a relation that is not an InsertableRelation") { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq + ) + + val message = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt + """.stripMargin) + }.getMessage + assert( + message.contains("does not allow insertion."), + "It is not allowed to insert into a table that is not an InsertableRelation." + ) + + dropTempTable("oneToTen") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 7900b3e8948d..6a1ddf2f8e98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.sources +import scala.language.existentials + +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -29,14 +32,15 @@ class PrunedScanSource extends RelationProvider { } case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends PrunedScan { + extends BaseRelation + with PrunedScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String]) = { + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala new file mode 100644 index 000000000000..8331a14c9295 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.scalatest.FunSuite + +class ResolvedDataSourceSuite extends FunSuite { + + test("builtin sources") { + assert(ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.jdbc.DefaultSource]) + + assert(ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.json.DefaultSource]) + + assert(ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.parquet.DefaultSource]) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala new file mode 100644 index 000000000000..cb287ba85c1f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { + + import caseInsensisitiveContext._ + + var originalDefaultSource: String = null + + var path: File = null + + var df: DataFrame = null + + override def beforeAll(): Unit = { + originalDefaultSource = conf.defaultDataSourceName + + path = Utils.createTempDir() + path.delete() + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + df = jsonRDD(rdd) + df.registerTempTable("jsonTable") + } + + override def afterAll(): Unit = { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + after { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + Utils.deleteRecursively(path) + } + + def checkLoad(): Unit = { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + checkAnswer(load(path.toString), df.collect()) + + // Test if we can pick up the data source name passed in load. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) + checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + val schema = StructType(StructField("b", StringType, true) :: Nil) + checkAnswer( + load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + sql("SELECT b FROM jsonTable").collect()) + } + + test("save with path and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + df.save(path.toString) + checkLoad() + } + + test("save with path and datasource, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() + } + + test("save with data source and options, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + checkLoad() + } + + test("save and save again") { + df.save(path.toString, "org.apache.spark.sql.json") + + var message = intercept[RuntimeException] { + df.save(path.toString, "org.apache.spark.sql.json") + }.getMessage + + assert( + message.contains("already exists"), + "We should complain that the path already exists.") + + if (path.exists()) Utils.deleteRecursively(path) + + df.save(path.toString, "org.apache.spark.sql.json") + checkLoad() + + df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + checkLoad() + + message = intercept[RuntimeException] { + df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + }.getMessage + + assert( + message.contains("Append mode is not supported"), + "We should complain that 'Append mode is not supported' for JSON source.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index b1e0919b7aed..3b47b8adf313 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.sql.{Timestamp, Date} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -33,12 +34,12 @@ class SimpleScanSource extends RelationProvider { } case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) - extends TableScan { + extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -51,14 +52,15 @@ class AllDataTypesScanSource extends SchemaRelationProvider { } case class AllDataTypesScan( - from: Int, - to: Int, - userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) - extends TableScan { + from: Int, + to: Int, + userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext) + extends BaseRelation + with TableScan { - override def schema = userSpecifiedSchema + override def schema: StructType = userSpecifiedSchema - override def buildScan() = { + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", @@ -72,7 +74,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -80,7 +82,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -101,7 +103,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -109,7 +111,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -264,7 +266,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution @@ -344,4 +346,24 @@ class TableScanSuite extends DataSourceTest { } assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } + + test("SPARK-5196 schema field with comment") { + sql( + """ + |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int) + |USING org.apache.spark.sql.sources.AllDataTypesScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + val planned = sql("SELECT * FROM student").queryExecution.executedPlan + val comments = planned.schema.fields.map { field => + if (field.metadata.contains("comment")) field.metadata.getString("comment") + else "NO_COMMENT" + }.mkString(",") + + assert(comments === "SN,SA,NO_COMMENT") + } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 123a1f629ab1..f38c796241df 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,10 @@ spark-hive_${scala.binary.version} ${project.version} + + com.google.guava + guava + ${hive.group} hive-cli diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6e07df18b0e1..832596fc8bee 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.util.Utils /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -57,13 +58,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) @@ -98,16 +93,14 @@ private[hive] class HiveThriftServer2(hiveContext: HiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) - if (isHTTPTransportMode(hiveConf)) { - val thriftCliService = new ThriftHttpCLIService(sparkSqlCliService) - setSuperField(this, "thriftCLIService", thriftCliService) - addService(thriftCliService) + val thriftCliService = if (isHTTPTransportMode(hiveConf)) { + new ThriftHttpCLIService(sparkSqlCliService) } else { - val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) - setSuperField(this, "thriftCLIService", thriftCliService) - addService(thriftCliService) + new ThriftBinaryCLIService(sparkSqlCliService) } + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) initCompositeService(hiveConf) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala old mode 100755 new mode 100644 index 7385952861ee..7e307bb4ad1e --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -23,6 +23,7 @@ import java.io._ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -32,14 +33,14 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim +import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" @@ -101,13 +102,7 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { @@ -145,6 +140,9 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + // use the specified database if specified + cli.processSelectDatabase(sessionState); + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) @@ -194,28 +192,29 @@ private[hive] object SparkSQLCLIDriver { val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) - def promptWithCurrentDB = s"$prompt$currentDB" - def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + def promptWithCurrentDB: String = s"$prompt$currentDB" + def continuedPromptWithDBSpaces: String = continuedPrompt + ReflectionUtils.invokeStatic( classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) var currentPrompt = promptWithCurrentDB var line = reader.readLine(currentPrompt + "> ") while (line != null) { - if (prefix.nonEmpty) { - prefix += '\n' - } + if (!line.startsWith("--")) { + if (prefix.nonEmpty) { + prefix += '\n' + } - if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { - line = prefix + line - ret = cli.processLine(line, true) - prefix = "" - currentPrompt = promptWithCurrentDB - } else { - prefix = prefix + line - currentPrompt = continuedPromptWithDBSpaces + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } } - line = reader.readLine(currentPrompt + "> ") } @@ -263,7 +262,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver driver.init() @@ -292,9 +292,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } } + var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach(out.println) + res.foreach{ l => + counter += 1 + out.println(l) + } res.clear() } } catch { @@ -311,7 +315,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = cret } - console.printInfo(s"Time taken: $timeTaken seconds", null) + var responseMsg = s"Time taken: $timeTaken seconds" + if (counter != 0) { + responseMsg += s", Fetched $counter row(s)" + } + console.printInfo(responseMsg , null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 158c22515972..97b46a01ba5b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -37,7 +38,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") sparkConf - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .setAppName(s"SparkSQL::${Utils.localHostName()}") .set("spark.sql.hive.version", HiveShim.version) .set( "spark.serializer", diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala deleted file mode 100644 index 89e9ede7261c..000000000000 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.util.concurrent.Executors - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.session.SessionManager - -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -import org.apache.hive.service.cli.SessionHandle - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def closeSession(sessionHandle: SessionHandle) { - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - } -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 60953576d0e3..b070fa8eaa46 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -26,20 +25,31 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.util.Utils + +class CliSuite extends FunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + } -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def runCliWithin( timeout: FiniteDuration, extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { @@ -94,8 +104,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin, cause) throw cause } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -121,6 +129,26 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK") + runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + } + + test("Single command with --database") { + runCliWithin(1.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala deleted file mode 100644 index b52a51d11e4a..000000000000 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.io.File -import java.net.ServerSocket -import java.sql.{Date, DriverManager, Statement} -import java.util.concurrent.TimeoutException - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} -import scala.util.Try - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.jdbc.HiveDriver -import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.GetInfoType -import org.apache.hive.service.cli.thrift.TCLIService.Client -import org.apache.hive.service.cli.thrift._ -import org.apache.thrift.protocol.TBinaryProtocol -import org.apache.thrift.transport.TSocket -import org.scalatest.FunSuite - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.hive.HiveShim - -/** - * Tests for the HiveThriftServer2 using JDBC. - * - * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be - * rebuilt after changing HiveThriftServer2 related code. - */ -class HiveThriftServer2Suite extends FunSuite with Logging { - Class.forName(classOf[HiveDriver].getCanonicalName) - - object TestData { - def getTestDataFilePath(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") - } - - val smallKv = getTestDataFilePath("small_kv.txt") - val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") - } - - def randomListeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - def withJdbcStatement( - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: Statement => Unit) { - val port = randomListeningPort - - startThriftServer(port, serverStartTimeout, httpMode) { - val jdbcUri = if (httpMode) { - s"jdbc:hive2://${"localhost"}:$port/" + - "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice" - } else { - s"jdbc:hive2://${"localhost"}:$port/" - } - - val user = System.getProperty("user.name") - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } - } - } - - def withCLIServiceClient( - serverStartTimeout: FiniteDuration = 1.minute)( - f: ThriftCLIServiceClient => Unit) { - val port = randomListeningPort - - startThriftServer(port) { - // Transport creation logics below mimics HiveConnection.createBinaryTransport - val rawTransport = new TSocket("localhost", port) - val user = System.getProperty("user.name") - val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) - val protocol = new TBinaryProtocol(transport) - val client = new ThriftCLIServiceClient(new Client(protocol)) - - transport.open() - - try { - f(client) - } finally { - transport.close() - } - } - } - - def startThriftServer( - port: Int, - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: => Unit) { - val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) - - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") - val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - val command = - if (httpMode) { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } else { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } - - val serverRunning = Promise[Unit]() - val buffer = new ArrayBuffer[String]() - val LOGGING_MARK = - s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " - var logTailingProcess: Process = null - var logFilePath: String = null - - def captureLogOutput(line: String): Unit = { - buffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverRunning.success(()) - } - } - - def captureThriftServerOutput(source: String)(line: String): Unit = { - if (line.startsWith(LOGGING_MARK)) { - logFilePath = line.drop(LOGGING_MARK.length).trim - // Ensure that the log file is created so that the `tail' command won't fail - Try(new File(logFilePath).createNewFile()) - logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") - .run(ProcessLogger(captureLogOutput, _ => ())) - } - } - - val env = Seq( - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - "SPARK_TESTING" -> "0") - - Process(command, None, env: _*).run(ProcessLogger( - captureThriftServerOutput("stdout"), - captureThriftServerOutput("stderr"))) - - try { - Await.result(serverRunning.future, serverStartTimeout) - f - } catch { - case cause: Exception => - cause match { - case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) - case _ => - } - logError( - s""" - |===================================== - |HiveThriftServer2Suite failure output - |===================================== - |HiveThriftServer2 command line: ${command.mkString(" ")} - |Binding port: $port - |System user: ${System.getProperty("user.name")} - | - |${buffer.mkString("\n")} - |========================================= - |End HiveThriftServer2Suite failure output - |========================================= - """.stripMargin, cause) - throw cause - } finally { - warehousePath.delete() - metastorePath.delete() - Process(stopScript, None, env: _*).run().exitValue() - // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Thread.sleep(3.seconds.toMillis) - Option(logTailingProcess).map(_.destroy()) - Option(logFilePath).map(new File(_).delete()) - } - } - - test("Test JDBC query execution") { - withJdbcStatement() { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("Test JDBC query execution in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("SPARK-3004 regression: result set containing NULL") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_null", - "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") - - (0 until 5).foreach { _ => - resultSet.next() - assert(resultSet.getInt(1) === 0) - assert(resultSet.wasNull()) - } - - assert(!resultSet.next()) - } - } - - test("GetInfo Thrift API") { - withCLIServiceClient() { client => - val user = System.getProperty("user.name") - val sessionHandle = client.openSession(user, "") - - assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue - } - - assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue - } - - assertResult(true, "Spark version shouldn't be \"Unknown\"") { - val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue - logInfo(s"Spark version: $version") - version != "Unknown" - } - } - } - - test("Checks Hive version") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("Checks Hive version in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("SPARK-4292 regression: result set iterator issue") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_4292", - "CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT key FROM test_4292") - - Seq(238, 86, 311, 27, 165).foreach { key => - resultSet.next() - assert(resultSet.getInt(1) === key) - } - - statement.executeQuery("DROP TABLE IF EXISTS test_4292") - } - } - - test("SPARK-4309 regression: Date type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_date", - "CREATE TABLE test_date(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") - - queries.foreach(statement.execute) - - assertResult(Date.valueOf("2011-01-01")) { - val resultSet = statement.executeQuery( - "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") - resultSet.next() - resultSet.getDate(1) - } - } - } - - test("SPARK-4407 regression: Complex type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - assertResult("""{238:"val_238"}""") { - val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - - assertResult("""["238","val_238"]""") { - val resultSet = statement.executeQuery( - "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - } - } -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala new file mode 100644 index 000000000000..4cf95e7bdfb2 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -0,0 +1,561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.File +import java.net.URL +import java.sql.{Date, DriverManager, Statement} + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.{Random, Try} + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.Logging +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.util.Utils + +object TestData { + def getTestDataFilePath(name: String): URL = { + Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") + } + + val smallKv = getTestDataFilePath("small_kv.txt") + val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") +} + +class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.binary + + private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", serverPort) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + try f(client) finally transport.close() + } + + test("GetInfo Thrift API") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } + + test("SPARK-3004 regression: result set containing NULL") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_null", + "CREATE TABLE test_null(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + + (0 until 5).foreach { _ => + resultSet.next() + assert(resultSet.getInt(1) === 0) + assert(resultSet.wasNull()) + } + + assert(!resultSet.next()) + } + } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) === key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } + + test("SPARK-4309 regression: Date type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_date", + "CREATE TABLE test_date(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") + + queries.foreach(statement.execute) + + assertResult(Date.valueOf("2011-01-01")) { + val resultSet = statement.executeQuery( + "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") + resultSet.next() + resultSet.getDate(1) + } + } + } + + test("SPARK-4407 regression: Complex type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + assertResult("""{238:"val_238"}""") { + val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + + assertResult("""["238","val_238"]""") { + val resultSet = statement.executeQuery( + "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + } + } + + test("test multiple session") { + import org.apache.spark.sql.SQLConf + var defaultV1: String = null + var defaultV2: String = null + + withMultipleConnectionJdbcStatement( + // create table + { statement => + + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + + queries.foreach(statement.execute) + + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + }, + + // first session, we get the default value of the session status + { statement => + + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + defaultV1 = rs1.getString(1) + assert(defaultV1 != "200") + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + + defaultV2 = rs2.getString(1) + assert(defaultV1 != "true") + rs2.close() + }, + + // second session, we update the session status + { statement => + + val queries = Seq( + s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291", + "SET hive.cli.print.header=true" + ) + + queries.map(statement.execute) + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + assert("spark.sql.shuffle.partitions=291" === rs1.getString(1)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert("hive.cli.print.header=true" === rs2.getString(1)) + rs2.close() + }, + + // third session, we get the latest session status, supposed to be the + // default value + { statement => + + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + assert(defaultV1 === rs1.getString(1)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert(defaultV2 === rs2.getString(1)) + rs2.close() + }, + + // accessing the cached data in another session + { statement => + + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + statement.executeQuery("UNCACHE TABLE test_table") + + // TODO need to figure out how to determine if the data loaded from cache + val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf3 = new collection.mutable.ArrayBuffer[Int]() + while (rs3.next()) { + buf3 += rs3.getInt(1) + } + rs3.close() + + assert(buf1 === buf3) + }, + + // accessing the uncached table + { statement => + + // TODO need to figure out how to determine if the data loaded from cache + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + } + ) + } +} + +class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.http + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } +} + +object ServerMode extends Enumeration { + val binary, http = Value +} + +abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { + Class.forName(classOf[HiveDriver].getCanonicalName) + + private def jdbcUri = if (mode == ServerMode.http) { + s"""jdbc:hive2://localhost:$serverPort/ + |default? + |hive.server2.transport.mode=http; + |hive.server2.thrift.http.path=cliservice + """.stripMargin.split("\n").mkString.trim + } else { + s"jdbc:hive2://localhost:$serverPort/" + } + + def withMultipleConnectionJdbcStatement(fs: (Statement => Unit)*) { + val user = System.getProperty("user.name") + val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } + val statements = connections.map(_.createStatement()) + + try { + statements.zip(fs).map { case (s, f) => f(s) } + } finally { + statements.map(_.close()) + connections.map(_.close()) + } + } + + def withJdbcStatement(f: Statement => Unit) { + withMultipleConnectionJdbcStatement(f) + } +} + +abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { + def mode: ServerMode.Value + + private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") + private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " + + private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + private var listeningPort: Int = _ + protected def serverPort: Int = listeningPort + + protected def user = System.getProperty("user.name") + + private var warehousePath: File = _ + private var metastorePath: File = _ + private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + + private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private var logPath: File = _ + private var logTailingProcess: Process = _ + private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + + private def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + s"""$startScript + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false + """.stripMargin.split("\\s+").toSeq + } + + private def startThriftServer(port: Int, attempt: Int) = { + warehousePath = Utils.createTempDir() + warehousePath.delete() + metastorePath = Utils.createTempDir() + metastorePath.delete() + logPath = null + logTailingProcess = null + + val command = serverStartCommand(port) + + diagnosisBuffer ++= + s""" + |### Attempt $attempt ### + |HiveThriftServer2 command line: $command + |Listening port: $port + |System user: $user + """.stripMargin.split("\n") + + logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") + + val env = Seq( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be started + // at a time, which is not Jenkins friendly. + "SPARK_PID_DIR" -> pidDir.getCanonicalPath) + + logPath = Process(command, None, env: _*).lines.collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } + + val serverStarted = Promise[Unit]() + + // Ensures that the following "tail" command won't fail. + logPath.createNewFile() + logTailingProcess = + // Using "-n +0" to make sure all lines in the log file are checked. + Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( + (line: String) => { + diagnosisBuffer += line + + if (line.contains("ThriftBinaryCLIService listening on") || + line.contains("Started ThriftHttpCLIService in http")) { + serverStarted.trySuccess(()) + } else if (line.contains("HiveServer2 is stopped")) { + // This log line appears when the server fails to start and terminates gracefully (e.g. + // because of port contention). + serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) + } + })) + + Await.result(serverStarted.future, 2.minute) + } + + private def stopThriftServer(): Unit = { + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue() + Thread.sleep(3.seconds.toMillis) + + warehousePath.delete() + warehousePath = null + + metastorePath.delete() + metastorePath = null + + Option(logPath).foreach(_.delete()) + logPath = null + + Option(logTailingProcess).foreach(_.destroy()) + logTailingProcess = null + } + + private def dumpLogs(): Unit = { + logError( + s""" + |===================================== + |HiveThriftServer2Suite failure output + |===================================== + |${diagnosisBuffer.mkString("\n")} + |========================================= + |End HiveThriftServer2Suite failure output + |========================================= + """.stripMargin) + } + + override protected def beforeAll(): Unit = { + // Chooses a random port between 10000 and 19999 + listeningPort = 10000 + Random.nextInt(10000) + diagnosisBuffer.clear() + + // Retries up to 3 times with different port numbers if the server fails to start + (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => + started.orElse { + listeningPort += 1 + stopThriftServer() + Try(startThriftServer(listeningPort, attempt)) + } + }.recover { + case cause: Throwable => + dumpLogs() + throw cause + }.get + + logInfo(s"HiveThriftServer2 started successfully") + } + + override protected def afterAll(): Unit = { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } +} diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 166c56b9dfe2..95a6e86d0546 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -18,8 +18,15 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} +import java.util.concurrent.Executors import java.util.{ArrayList => JArrayList, Map => JMap} +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} @@ -29,10 +36,10 @@ import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.HiveSession +import org.apache.hive.service.cli.session.{SessionManager, HiveSession} import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -71,7 +78,7 @@ private[hive] class SparkExecuteStatementOperation( sessionToActivePool: SMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -185,6 +192,10 @@ private[hive] class SparkExecuteStatementOperation( def run(): Unit = { logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) + hiveContext.sparkContext.setJobDescription(statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) @@ -194,15 +205,11 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } @@ -220,3 +227,42 @@ private[hive] class SparkExecuteStatementOperation( setState(OperationState.FINISHED) } } + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession( + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + + super.openSession(username, passwd, sessionConf, withImpersonation, delegationToken) + } + + override def closeSession(sessionHandle: SessionHandle) { + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index eaf7a1ddd499..178eb1af7cdc 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -18,8 +18,15 @@ package org.apache.spark.sql.hive.thriftserver import java.sql.{Date, Timestamp} +import java.util.concurrent.Executors import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} @@ -27,10 +34,10 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.HiveSession +import org.apache.hive.service.cli.session.{SessionManager, HiveSession} import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -72,7 +79,7 @@ private[hive] class SparkExecuteStatementOperation( // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -156,6 +163,10 @@ private[hive] class SparkExecuteStatementOperation( def run(): Unit = { logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) + hiveContext.sparkContext.setJobDescription(statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) @@ -165,15 +176,11 @@ private[hive] class SparkExecuteStatementOperation( logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => } - hiveContext.sparkContext.setJobDescription(statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } iter = { val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } @@ -191,3 +198,43 @@ private[hive] class SparkExecuteStatementOperation( setState(OperationState.FINISHED) } } + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + + super.openSession(protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + } + + override def closeSession(sessionHandle: SessionHandle) { + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0d934620aca0..81ee48ef4152 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -225,13 +225,22 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Needs constant object inspectors "udf_round", + // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive + // is src(key STRING, value STRING), and in the reflect.q, it failed in + // Integer.valueOf, which expect the first argument passed as STRING type not INT. + "udf_reflect", + // Sort with Limit clause causes failure. "ctas", "ctas_hadoop20", // timestamp in array, the output format of Hive contains double quotes, while // Spark SQL doesn't - "udf_sort_array" + "udf_sort_array", + + // It has a bug and it has been fixed by + // https://issues.apache.org/jira/browse/HIVE-7673 (in Hive 0.14 and trunk). + "input46" ) ++ HiveShim.compatibilityBlackList /** @@ -357,6 +366,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "database_drop", "database_location", "database_properties", + "date_1", "date_2", "date_3", "date_4", @@ -517,10 +527,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl2", "inputddl3", "inputddl4", + "inputddl5", "inputddl6", "inputddl7", "inputddl8", "insert1", + "insert1_overwrite_partitions", "insert2_overwrite_partitions", "insert_compressed", "join0", @@ -625,6 +637,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "mapreduce8", "merge1", "merge2", + "merge4", "mergejoins", "multiMapJoin1", "multiMapJoin2", @@ -638,6 +651,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nonblock_op_deduplicate", "notable_alias1", "notable_alias2", + "nullformatCTAS", "nullgroup", "nullgroup2", "nullgroup3", @@ -717,6 +731,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "select_unquote_and", "select_unquote_not", "select_unquote_or", + "semicolon", "semijoin", "serde_regex", "serde_reported_schema", @@ -786,6 +801,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", + "udaf_number_format", "udf2", "udf5", "udf6", @@ -883,6 +899,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_power", "udf_radians", "udf_rand", + "udf_reflect2", "udf_regexp", "udf_regexp_extract", "udf_regexp_replace", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 000000000000..65d070bd3cbd --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join is true. + */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 58b0722464be..21dce8d8a565 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,8 +21,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../../pom.xml @@ -59,6 +59,11 @@ ${hive.group} hive-exec + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + org.codehaus.jackson jackson-mapper-asl @@ -84,6 +89,25 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + ${project.version} + test + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d2cfd8e0d66..78c46b522f06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,30 +18,28 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, InputStreamReader, PrintStream} -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.DataSourceStrategy +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} import org.apache.spark.sql.types._ /** @@ -51,10 +49,6 @@ import org.apache.spark.sql.types._ class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - } - /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive @@ -63,32 +57,50 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + + /** + * When true, a table created by a Hive CTAS statement (no USING clause) will be + * converted to a data source table, using the data source set by spark.sql.sources.default. + * The table in CTAS statement will be converted when it meets any of the following conditions: + * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or + * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml + * is either TextFile or SequenceFile. + * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe + * is specified (no ROW FORMAT SERDE clause). + * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format + * and no SerDe is specified (no ROW FORMAT SERDE clause). + */ + protected[sql] def convertCTAS: Boolean = + getConf("spark.sql.hive.convertCTAS", "false").toBoolean + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) - override def sql(sqlText: String): SchemaRDD = { + @transient + protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_)) + + override def sql(sqlText: String): DataFrame = { val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) + // 这里需要重构,直接使用 DataFrame(this, parseSql(sqlText)) 即可 + val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false) + DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") + sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } } - /** - * Creates a table using the schema of the given class. - * - * @param tableName The name of the table to create. - * @param allowExisting When false, an exception will be thrown if the table already exists. - * @tparam A A case class that is used to describe the schema of the table to be created. - */ - def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) { - catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) - } - /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -114,7 +126,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ @Experimental def analyze(tableName: String) { - val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { case relation: MetastoreRelation => @@ -170,18 +182,19 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient - protected lazy val outputBuffer = new java.io.OutputStream { + protected lazy val outputBuffer = new java.io.OutputStream { var pos: Int = 0 var buffer = new Array[Int](10240) def write(i: Int): Unit = { @@ -189,7 +202,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { pos = (pos + 1) % buffer.size } - override def toString = { + override def toString: String = { val (end, start) = buffer.splitAt(pos) val input = new java.io.InputStream { val iterator = (start ++ end).iterator @@ -208,30 +221,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } - /** - * SQLConf and HiveConf contracts: - * - * 1. reuse existing started SessionState if any - * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the - * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be - * set in the SQLConf *as well as* in the HiveConf. - */ - @transient protected[hive] lazy val (hiveconf, sessionState) = - Option(SessionState.get()) - .orElse { - val newState = new SessionState(new HiveConf(classOf[SessionState])) - // Only starts newly created `SessionState` instance. Any existing `SessionState` instance - // returned by `SessionState.get()` must be the most recently started one. - SessionState.start(newState) - Some(newState) - } - .map { state => - setConf(state.getConf.getAllProperties) - if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") - if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") - (state.getConf, state) - } - .get + protected[hive] def sessionState = tlSession.get().asInstanceOf[this.SQLSession].sessionState + + protected[hive] def hiveconf = tlSession.get().asInstanceOf[this.SQLSession].hiveconf override def setConf(key: String, value: String): Unit = { super.setConf(key, value) @@ -245,19 +237,61 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry + new HiveFunctionRegistry with OverrideFunctionRegistry { + def caseSensitive: Boolean = false + } /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { - override val extendedRules = + override val extendedResolutionRules = + catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: + sources.PreInsertCastAndRename :: Nil } + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } + + protected[hive] class SQLSession extends super.SQLSession { + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + } + + protected[hive] lazy val hiveconf: HiveConf = { + setConf(sessionState.getConf.getAllProperties) + sessionState.getConf + } + + /** + * SQLConf and HiveConf contracts: + * + * 1. reuse existing started SessionState if any + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. + */ + protected[hive] lazy val sessionState: SessionState = { + var state = SessionState.get() + if (state == null) { + state = new SessionState(new HiveConf(classOf[SessionState])) + SessionState.start(state) + } + if (state.out == null) { + state.out = new PrintStream(outputBuffer, true, "UTF-8") + } + if (state.err == null) { + state.err = new PrintStream(outputBuffer, true, "UTF-8") + } + state + } + } + /** * Runs the specified SQL query using Hive. */ @@ -352,7 +386,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] val planner = hivePlanner /** Extends QueryExecution with hive specific features. */ - protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected[sql] class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { + // Like what we do in runHive, makes sure the session represented by the + // `sessionState` field is activated. + if (SessionState.get() != sessionState) { + SessionState.start(sessionState) + } /** * Returns the result as a hive compatible sequence of strings. For native commands, the @@ -408,7 +448,7 @@ private object HiveContext { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString + case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 82dba99900df..74ae984f3486 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,9 +239,10 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -267,7 +268,8 @@ private[hive] trait HiveInspectors { val temp = new Array[Byte](writable.getLength) System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp - case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get() + case poi: WritableConstantDateObjectInspector => + DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data mi.getWritableConstantValue.map { case (k, v) => @@ -283,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -304,7 +309,8 @@ private[hive] trait HiveInspectors { System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).get() + DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object // if next timestamp is null, so Timestamp object is cloned case x: TimestampObjectInspector if x.preferWritable() => @@ -338,11 +344,16 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + case _: JavaDateObjectInspector => + (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { @@ -404,7 +415,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) @@ -426,7 +437,7 @@ private[hive] trait HiveInspectors { case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) - case _: DateObjectInspector => a.asInstanceOf[java.sql.Date] + case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } @@ -588,7 +599,7 @@ private[hive] trait HiveInspectors { case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. - case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1a49f09bd998..f1c0bd92aa23 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,26 +20,27 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} -import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} - -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.TableType -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException} -import org.apache.hadoop.hive.ql.metadata.InvalidTableException +import com.google.common.base.Objects +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} +import org.apache.hadoop.hive.metastore.{TableType, Warehouse} +import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} +import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} +import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NoSuchTableException, Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -52,10 +53,13 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) + /** Usages should lock on `this`. */ + protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) + // TODO: Use this everywhere instead of tuples or databaseName, tableName,. /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) { - def toLowerCase = QualifiedTableName(database.toLowerCase, name.toLowerCase) + def toLowerCase: QualifiedTableName = QualifiedTableName(database.toLowerCase, name.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -63,14 +67,36 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = client.getTable(in.database, in.name) - val schemaString = table.getProperty("spark.sql.sources.schema") - val userSpecifiedSchema = - if (schemaString == null) { - None - } else { - Some(DataType.fromJson(schemaString).asInstanceOf[StructType]) + val table = HiveMetastoreCatalog.this.synchronized { + client.getTable(in.database, in.name) + } + + def schemaStringFromParts: Option[String] = { + Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") + if (part == null) { + throw new AnalysisException( + s"Could not read schema from the metastore because it is corrupted " + + s"(missing part ${index} of the schema).") + } + + part + } + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) + // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap @@ -89,8 +115,16 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + override def refreshTable(databaseName: String, tableName: String): Unit = { + // refreshTable does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRealtions converted from Hive Parquet tables and + // adding converted ParquetRealtions into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -99,22 +133,39 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + /** + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + */ def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) = { + options: Map[String, String], + isExternal: Boolean): Unit = { val (dbName, tblName) = processDatabaseAndTableName("default", tableName) val tbl = new Table(dbName, tblName) tbl.setProperty("spark.sql.sources.provider", provider) if (userSpecifiedSchema.isDefined) { - tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json) + val threshold = hive.conf.schemaStringLengthThreshold + val schemaJsonString = userSpecifiedSchema.get.json + // Split the JSON string. + val parts = schemaJsonString.grouped(threshold).toSeq + tbl.setProperty("spark.sql.sources.schema.numParts", parts.size.toString) + parts.zipWithIndex.foreach { case (part, index) => + tbl.setProperty(s"spark.sql.sources.schema.part.${index}", part) + } } options.foreach { case (key, value) => tbl.setSerdeParam(key, value) } - tbl.setProperty("EXTERNAL", "TRUE") - tbl.setTableType(TableType.EXTERNAL_TABLE) + if (isExternal) { + tbl.setProperty("EXTERNAL", "TRUE") + tbl.setTableType(TableType.EXTERNAL_TABLE) + } else { + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + } // create the table synchronized { @@ -122,29 +173,48 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } - def tableExists(tableIdentifier: Seq[String]): Boolean = { + def hiveDefaultTableFilePath(tableName: String): String = synchronized { + val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase) + + hiveWarehouse.getTablePath(currentDatabase, tableName).toString + } + + def tableExists(tableIdentifier: Seq[String]): Boolean = synchronized { val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - hive.sessionState.getCurrentDatabase) + val databaseName = + tableIdent + .lift(tableIdent.size - 2) + .getOrElse(hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - try { - client.getTable(databaseName, tblName) != null - } catch { - case ie: InvalidTableException => false - } + client.getTable(databaseName, tblName, false) != null } def lookupRelation( tableIdentifier: Seq[String], - alias: Option[String]): LogicalPlan = synchronized { + alias: Option[String]): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last - val table = client.getTable(databaseName, tblName) + val table = try { + synchronized { + client.getTable(databaseName, tblName) + } + } catch { + case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + throw new NoSuchTableException + } if (table.getProperty("spark.sql.sources.provider") != null) { - cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + val dataSourceTable = + cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + // Then, if alias is specified, wrap the table with a Subquery using the alias. + // Othersie, wrap the table with a Subquery using the table name. + val withAlias = + alias.map(a => Subquery(a, dataSourceTable)).getOrElse( + Subquery(tableIdent.last, dataSourceTable)) + + withAlias } else if (table.isView) { // if the unresolved relation is from hive view // parse the text into logic node. @@ -152,16 +222,111 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else { val partitions: Seq[Partition] = if (table.isPartitioned) { - HiveShim.getAllPartitionsOf(client, table).toSeq + synchronized { + HiveShim.getAllPartitionsOf(client, table).toSeq + } } else { Nil } - // Since HiveQL is case insensitive for table names we make them all lowercase. - MetastoreRelation( - databaseName, tblName, alias)( - table.getTTable, partitions.map(part => part.getTPartition))(hive) + MetastoreRelation(databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) + } + } + + private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // serialize the Metastore schema to JSON and pass it as a data source option because of the + // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { + val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) + val partitionColumnDataTypes = partitionSchema.map(_.dataType) + val partitions = metastoreRelation.hiveQlPartitions.map { p => + val location = p.getLocation + val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) + }) + ParquetPartition(values, location) + } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val paths = partitions.map(_.path) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation + } else { + val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } + + result.newInstance() + } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { + val dbName = if (!caseSensitive) { + if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None + } else { + databaseName + } + val db = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + + client.getAllTables(db).map(tableName => (tableName, false)) } /** @@ -191,7 +356,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val hiveSchema: JList[FieldSchema] = if (schema == null || schema.isEmpty) { crtTbl.getCols } else { - schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) + schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), null)) } tbl.setFields(hiveSchema) @@ -222,9 +387,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) - import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.io.Text + import org.apache.hadoop.mapred.TextInputFormat tbl.setInputFormatClass(classOf[TextInputFormat]) tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) @@ -265,6 +430,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with if (crtTbl != null && crtTbl.getLineDelim() != null) { tbl.setSerdeParam(serdeConstants.LINE_DELIM, crtTbl.getLineDelim()) } + HiveShim.setTblNullFormat(crtTbl, tbl) if (crtTbl != null && crtTbl.getSerdeProps() != null) { val iter = crtTbl.getSerdeProps().entrySet().iterator() @@ -346,13 +512,87 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + /** + * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet + * data source relations for better performance. + * + * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. + */ + object ParquetConversions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.resolved) { + return plan + } + + // Collects all `MetastoreRelation`s which should be replaced + val toBeReplaced = plan.collect { + // Write path + case InsertIntoTable(relation: MetastoreRelation, _, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Write path + case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + + // Read path + case p @ PhysicalOperation(_, _, relation: MetastoreRelation) + if hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = relation.output.zip(parquetRelation.output) + (relation, parquetRelation, attributedRewrites) + } + + val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap + val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) + + // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes + // attribute IDs referenced in other nodes. + plan.transformUp { + case r: MetastoreRelation if relationMap.contains(r) => + val parquetRelation = relationMap(r) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) + + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + if relationMap.contains(r) => + val parquetRelation = relationMap(r) + InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + + case other => other.transformExpressions { + case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) + } + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. */ object CreateTables extends Rule[LogicalPlan] { import org.apache.hadoop.hive.ql.Context - import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer} + import org.apache.hadoop.hive.ql.parse.{ASTNode, QB, SemanticAnalyzer} def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Wait until children are resolved. @@ -383,24 +623,69 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Some(sa.getQB().getTableDesc) } - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - desc) + // Check if the query specifies file format or storage handler. + val hasStorageSpec = desc match { + case Some(crtTbl) => + crtTbl != null && (crtTbl.getSerName != null || crtTbl.getStorageHandler != null) + case None => false + } + + if (hive.convertCTAS && !hasStorageSpec) { + // Do the conversion when spark.sql.hive.convertCTAS is true and the query + // does not specify any storage format (file format and storage handler). + if (dbName.isDefined) { + throw new AnalysisException( + "Cannot specify database name in a CTAS statement " + + "when spark.sql.hive.convertCTAS is set to true.") + } + + val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableUsingAsSelect( + tblName, + hive.conf.defaultDataSourceName, + temporary = false, + mode, + options = Map.empty[String, String], + child + ) + } else { + execution.CreateTableAsSelect( + databaseName, + tableName, + child, + allowExisting, + desc) + } case p: LogicalPlan if p.resolved => p case p @ CreateTableAsSelect(db, tableName, child, allowExisting, None) => val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - execution.CreateTableAsSelect( - databaseName, - tableName, - child, - allowExisting, - None) + if (hive.convertCTAS) { + if (dbName.isDefined) { + throw new AnalysisException( + "Cannot specify database name in a CTAS statement " + + "when spark.sql.hive.convertCTAS is set to true.") + } + + val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableUsingAsSelect( + tblName, + hive.conf.defaultDataSourceName, + temporary = false, + mode, + options = Map.empty[String, String], + child + ) + } else { + val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + execution.CreateTableAsSelect( + databaseName, + tableName, + child, + allowExisting, + None) + } } } @@ -413,11 +698,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => castChildOutput(p, table, child) } - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { + def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) + : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) val tableOutputDataTypes = (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) @@ -426,10 +712,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with p } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) { + .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite) + InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -455,7 +741,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? - override def unregisterAllTables() = {} + override def unregisterAllTables(): Unit = {} } /** @@ -467,15 +753,15 @@ private[hive] case class InsertIntoHiveTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: Boolean) + overwrite: Boolean, + ifNotExists: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType) + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } @@ -483,23 +769,36 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: TTable, val partitions: Seq[TPartition]) (@transient sqlContext: SQLContext) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { self: Product => + override def equals(other: scala.Any): Boolean = other match { + case relation: MetastoreRelation => + databaseName == relation.databaseName && + tableName == relation.tableName && + alias == relation.alias && + output == relation.output + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(databaseName, tableName, alias, output) + } + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. // Right now, using org.apache.hadoop.hive.ql.metadata.Table and // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - @transient val hiveQlTable = new Table(table) + @transient val hiveQlTable: Table = new Table(table) - @transient val hiveQlPartitions = partitions.map { p => + @transient val hiveQlPartitions: Seq[Partition] = partitions.map { p => new Partition(hiveQlTable, p) } - @transient override lazy val statistics = Statistics( + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) @@ -519,6 +818,15 @@ private[hive] case class MetastoreRelation } ) + /** Only compare database and tablename, not alias. */ + override def sameResult(plan: LogicalPlan): Boolean = { + plan match { + case mr: MetastoreRelation => + mr.databaseName == databaseName && mr.tableName == tableName + case _ => false + } + } + val tableDesc = HiveShim.getTableDesc( Class.forName( hiveQlTable.getSerializationLib, @@ -534,9 +842,9 @@ private[hive] case class MetastoreRelation ) implicit class SchemaAttribute(f: FieldSchema) { - def toAttribute = AttributeReference( + def toAttribute: AttributeReference = AttributeReference( f.getName, - sqlContext.ddlParser.parseType(f.getType), + HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) @@ -555,14 +863,15 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + + override def newInstance(): MetastoreRelation = { + MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) + } } -object HiveMetastoreTypes { - protected val ddlParser = new DDLParser - def toDataType(metastoreType: String): DataType = synchronized { - ddlParser.parseType(metastoreType) - } +private[hive] object HiveMetastoreTypes { + def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 5e29e57d9358..d46c25f76051 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,22 +19,30 @@ package org.apache.spark.sql.hive import java.sql.Date + +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} + +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils -import org.apache.spark.sql.SparkSQLParser +import org.apache.spark.sql.{AnalysisException, SparkSQLParser} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable} +import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -46,53 +54,11 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -/** - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. - * @param table The table to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. - * It is effective only when the table is a Hive table. - */ -case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends Command { - override def output = Seq( - // Column names are based on Hive. - AttributeReference("col_name", StringType, nullable = false)(), - AttributeReference("data_type", StringType, nullable = false)(), - AttributeReference("comment", StringType, nullable = false)()) -} - /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ +// 命名问题,不够直观,应该要体现他是一个转换器的作用 private[hive] object HiveQl { - protected val nativeCommands = Seq( - "TOK_DESCFUNCTION", - "TOK_DESCDATABASE", - "TOK_SHOW_CREATETABLE", - "TOK_SHOWCOLUMNS", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWINDEXES", - "TOK_SHOWPARTITIONS", - "TOK_SHOWTABLES", - "TOK_SHOW_TBLPROPERTIES", - - "TOK_LOCKTABLE", - "TOK_SHOWLOCKS", - "TOK_UNLOCKTABLE", - - "TOK_CREATEROLE", - "TOK_DROPROLE", - "TOK_GRANT", - "TOK_GRANT_ROLE", - "TOK_REVOKE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - - "TOK_CREATEFUNCTION", - "TOK_DROPFUNCTION", - + protected val nativeCommands = Seq( // 应该把这个单独抽出来放到一个object里面,便于维护 + "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", @@ -110,37 +76,77 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_DROPDATABASE", - "TOK_DROPINDEX", - "TOK_MSCK", - "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", + + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_CREATEROLE", "TOK_CREATEVIEW", + + "TOK_DESCDATABASE", + "TOK_DESCFUNCTION", + + "TOK_DROPDATABASE", + "TOK_DROPFUNCTION", + "TOK_DROPINDEX", + "TOK_DROPROLE", + "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", - + "TOK_DROPVIEW_PROPERTIES", + "TOK_EXPORT", + + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_IMPORT", + "TOK_LOAD", - - "TOK_SWITCHDATABASE" + + "TOK_LOCKTABLE", + + "TOK_MSCK", + + "TOK_REVOKE", + + "TOK_SHOW_COMPACTIONS", + "TOK_SHOW_CREATETABLE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLE_PRINCIPALS", + "TOK_SHOW_ROLES", + "TOK_SHOW_SET_ROLE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOW_TRANSACTIONS", + "TOK_SHOWCOLUMNS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWLOCKS", + "TOK_SHOWPARTITIONS", + + "TOK_SWITCHDATABASE", + + "TOK_UNLOCKTABLE" ) // Commands that we do not need to explain. - protected val noExplainCommands = Seq( + protected val noExplainCommands = Seq( // 同上,应该抽出去 "TOK_DESCTABLE", + "TOK_SHOWTABLES", "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. ) ++ nativeCommands + // 移到hivecontext 里面去, 放在这里不合适,需要重构 protected val hqlParser = { val fallback = new ExtendedHiveQlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } /** @@ -152,7 +158,7 @@ private[hive] object HiveQl { * have clean copy semantics. Therefore, users of this class should take care when * copying/modifying trees that might be used elsewhere. */ - implicit class TransformableNode(n: ASTNode) { + implicit class TransformableNode(n: ASTNode) { // 这个比较有意思, 对 astnode 进行transform操作,必须使用隐式转换吗? /** * Returns a copy of this node where `rule` has been recursively applied to it and all of its * children. When `rule` does not apply to a given node it is left unchanged. @@ -201,8 +207,8 @@ private[hive] object HiveQl { * Right now this function only checks the name, type, text and children of the node * for equality. */ - def checkEquals(other: ASTNode) { - def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + def checkEquals(other: ASTNode): Unit = { + def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { sys.error(s"$field does not match for trees. " + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") } @@ -214,17 +220,11 @@ private[hive] object HiveQl { val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] leftChildren zip rightChildren foreach { - case (l,r) => l checkEquals r + case (l, r) => l checkEquals r } } } - class ParseException(sql: String, cause: Throwable) - extends Exception(s"Failed to parse: $sql", cause) - - class SemanticException(msg: String) - extends Exception(s"Error in semantic analysis: $msg") - /** * Returns the AST for the given SQL string. */ @@ -242,10 +242,13 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser(sql) + // 重构时删除 + def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) + + val errorRegEx = "line (\\d+):(\\d+) (.*)".r /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String) = { + def createPlan(sql: String): LogicalPlan = { try { val tree = getAst(sql) if (nativeCommands contains tree.getText) { @@ -257,19 +260,28 @@ private[hive] object HiveQl { } } } catch { - case e: Exception => throw new ParseException(sql, e) - case e: NotImplementedError => sys.error( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) + case pe: org.apache.hadoop.hive.ql.parse.ParseException => + pe.getMessage match { + case errorRegEx(line, start, message) => + throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) + case otherMessage => + throw new AnalysisException(otherMessage) + } + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${dumpTree(getAst(sql))} + |$e + |${e.getStackTrace.head} + """.stripMargin) } } /** Creates LogicalPlan for a given VIEW */ - def createPlanForView(view: Table, alias: Option[String]) = alias match { + def createPlanForView(view: Table, alias: Option[String]): Subquery = alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." case None => Subquery(view.getTableName, createPlan(view.getViewExpandedText)) @@ -300,6 +312,7 @@ private[hive] object HiveQl { /** @return matches of the form (tokenName, children). */ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { case t: ASTNode => + CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None @@ -322,7 +335,7 @@ private[hive] object HiveQl { clauses } - def getClause(clauseName: String, nodeList: Seq[Node]) = + def getClause(clauseName: String, nodeList: Seq[Node]): Node = getClauseOption(clauseName, nodeList).getOrElse(sys.error( s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) @@ -474,23 +487,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)())) + ExplainCommand(OneRowRelation) case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.getText => val Some(crtTbl) :: _ :: extended :: Nil = getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(crtTbl), - Seq(AttributeReference("plan", StringType,nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_EXPLAIN", explainArgs) => // Ignore FORMATTED if present. val Some(query) :: _ :: extended :: Nil = getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) ExplainCommand( nodeToPlan(query), - Seq(AttributeReference("plan", StringType, nullable = false)()), - extended != None) + extended = extended.isDefined) case Token("TOK_DESCTABLE", describeArgs) => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -509,15 +520,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. val tableIdent = extractTableIdent(nameParts.head) DescribeCommand( - UnresolvedRelation(tableIdent, None), extended.isDefined) + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) case Token(".", dbName :: tableName :: colName :: Nil) => // It is describing a column with the format like "describe db.table column". NativePlaceholder case tableName => // It is describing a table with the format like "describe table". DescribeCommand( - UnresolvedRelation(Seq(tableName.getText), None), - extended.isDefined) + UnresolvedRelation(Seq(tableName.getText), None), isExtended = extended.isDefined) } } // All other cases. @@ -552,6 +562,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TBLTEXTFILE", // Stored as TextFile "TOK_TBLRCFILE", // Stored as RCFile "TOK_TBLORCFILE", // Stored as ORC File + "TOK_TBLPARQUETFILE", // Stored as PARQUET "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -568,9 +579,26 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder - case Token("TOK_QUERY", - Token("TOK_FROM", fromClause :: Nil) :: - insertClauses) => + case Token("TOK_QUERY", queryArgs) + if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => + + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) + } + + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + } // Return one query for each insert clause. val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -611,8 +639,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_LATERAL_VIEW"), singleInsert) } - - val relations = nodeToRelation(fromClause) + + val relations = fromClause match { + case Some(f) => nodeToRelation(f) + case None => OneRowRelation + } + val withWhere = whereClause.map { whereNode => val Seq(whereExpr) = whereNode.getChildren.toSeq Filter(nodeToExpr(whereExpr), relations) @@ -627,29 +659,65 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", Nil) :: + Token("TOK_SERDE", inputSerdeClause) :: Token("TOK_RECORDWRITER", writerClause) :: // TODO: Need to support other types of (in/out)put Token(script, Nil) :: - Token("TOK_SERDE", serdeClause) :: + Token("TOK_SERDE", outputSerdeClause) :: Token("TOK_RECORDREADER", readerClause) :: - outputClause :: Nil) :: Nil) => - - val output = outputClause match { - case Token("TOK_ALIASLIST", aliases) => - aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() } - case Token("TOK_TABCOLLIST", attributes) => - attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(name, nodeToDataType(dataType))() } + outputClause) :: Nil) => + + val (output, schemaLess) = outputClause match { + case Token("TOK_ALIASLIST", aliases) :: Nil => + (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, + false) + case Token("TOK_TABCOLLIST", attributes) :: Nil => + (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() }, false) + case Nil => + (List(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) } + + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { + case Token("TOK_SERDEPROPS", propsClause) :: Nil => + val rowFormat = propsClause.map { + case Token(name, Token(value, Nil) :: Nil) => (name, value) + } + (rowFormat, "", Nil) + + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => + (Nil, serdeClass, Nil) + + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: + Token("TOK_TABLEPROPERTIES", + Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => + val serdeProps = propsClause.map { + case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => + (name, value) + } + (Nil, serdeClass, serdeProps) + + case Nil => (Nil, "", Nil) + } + + val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) + val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + val schema = HiveScriptIOSchema( + inRowFormat, outRowFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, schemaLess) + Some( logical.ScriptTransformation( inputExprs.map(nodeToExpr), unescapedScript, output, - withWhere)) + withWhere, schema)) case _ => None } @@ -660,12 +728,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = false, - Some(alias.toLowerCase), - withWhere) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = false, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + withWhere) }.getOrElse(withWhere) // The projection of the query can either be a normal projection, an aggregation @@ -744,7 +814,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } // If there are multiple INSERTS just UNION them together into on query. - queries.reduceLeft(Union) + val query = queries.reduceLeft(Union) + + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) @@ -765,12 +838,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - nodeToRelation(relationClause)) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = isOuter.nonEmpty, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + nodeToRelation(relationClause)) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -808,7 +883,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLESPLITSAMPLE", Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation) + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + require( + fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) + && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be on interval [0, 100]") + Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, + relation) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => @@ -926,21 +1009,48 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite) + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val tableIdent = extractTableIdent(tableNameParts) + + val partitionKeys = partitionClause.map(_.getChildren.map { + // Parse partitions. We also make keys case insensitive. + case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", - e :: Nil) => + case Token("TOK_SELEXPR", e :: Nil) => Some(nodeToExpr(e)) - case Token("TOK_SELEXPR", - e :: Token(alias, Nil) :: Nil) => + case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) + case Token("TOK_SELEXPR", e :: aliasChildren) => + var aliasNames = ArrayBuffer[String]() + aliasChildren.foreach { _ match { + case Token(name, Nil) => aliasNames += cleanIdentifier(name) + case _ => + } + } + Some(MultiAlias(nodeToExpr(e), aliasNames)) + /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None @@ -965,6 +1075,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Case insensitive matches */ val ARRAY = "(?i)ARRAY".r + val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r @@ -997,16 +1108,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - case other => GetField(other, attr) + UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) + case other => UnresolvedGetField(other, attr) } /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => Star(None) + case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - Star(Some(name)) + UnresolvedStar(Some(name)) /* Aggregate Functions */ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) @@ -1057,6 +1168,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Cast(nodeToExpr(arg), DateType) /* Arithmetic */ + case Token("+", child :: Nil) => nodeToExpr(child) case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) @@ -1137,20 +1249,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C CreateArray(children.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) + case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => UnresolvedFunction(name, args.map(nodeToExpr)) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, Star(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1161,21 +1274,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing @@ -1193,6 +1306,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => + Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : @@ -1202,7 +1318,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): Generator = { + def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head val attributes = nodes.flatMap { @@ -1212,13 +1328,17 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C function match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - Explode(attributes, nodeToExpr(child)) + (Explode(nodeToExpr(child)), attributes) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - HiveGenericUdtf( - new HiveFunctionWrapper(functionName), - attributes, - children.map(nodeToExpr)) + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + + (HiveGenericUdtf( + new HiveFunctionWrapper(functionClassName), + children.map(nodeToExpr)), attributes) case a: ASTNode => throw new NotImplementedError( @@ -1231,7 +1351,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) : StringBuilder = { node match { - case a: ASTNode => builder.append((" " * indent) + a.getText + "\n") + case a: ASTNode => builder.append( + (" " * indent) + a.getText + " " + + a.getLine + ", " + + a.getTokenStartIndex + "," + + a.getTokenStopIndex + ", " + + a.getCharPositionInLine + "\n") case other => sys.error(s"Non ASTNode encountered: $other") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6952b126cf89..be9249a8b1f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -20,19 +20,18 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.CreateTableUsing +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -55,16 +54,15 @@ private[hive] trait HiveStrategies { */ @Experimental object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: SchemaRDD) { - def lowerCase = - new SchemaRDD(s.sqlContext, s.logicalPlan) + implicit class LogicalPlanHacks(s: DataFrame) { + def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - def addPartitioningAttributes(attrs: Seq[Attribute]) = { + def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new SchemaRDD( + DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -74,7 +72,7 @@ private[hive] trait HiveStrategies { } implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]) = + def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = OutputFaker( originalPlan.output.map(a => newOutput.find(a.name.toLowerCase == _.name.toLowerCase) @@ -86,7 +84,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet => + hiveContext.convertMetastoreParquet && + !hiveContext.conf.parquetUseDataSourceApi => // Filter out all predicates that only deal with partition keys val partitionsKeys = AttributeSet(relation.partitionKeys) @@ -97,13 +96,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = otherPredicates.map(_ transform { + val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true)) + }).reduceOption(And).getOrElse(Literal(true))) - val unresolvedProjection = projectList.map(_ transform { + val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }) + }).map(Column(_)) try { if (relation.hiveQlTable.isPartitioned) { @@ -120,30 +119,37 @@ private[hive] trait HiveStrategies { val inputData = new GenericMutableRow(relation.partitionKeys.size) val pruningCondition = if (codegenEnabled) { - GeneratePredicate(castedPredicate) + GeneratePredicate.generate(castedPredicate) } else { - InterpretedPredicate(castedPredicate) + InterpretedPredicate.create(castedPredicate) } val partitions = relation.hiveQlPartitions.filter { part => val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) } - hiveContext - .parquetFile(partitions.map(_.getLocation).mkString(",")) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil + val partitionLocations = partitions.map(_.getLocation) + + if (partitionLocations.isEmpty) { + PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + } else { + hiveContext + .parquetFile(partitionLocations: _*) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection: _*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } + } else { hiveContext .parquetFile(relation.hiveQlTable.getDataLocation.toString) @@ -167,20 +173,22 @@ private[hive] trait HiveStrategies { object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.ScriptTransformation(input, script, output, child) => - ScriptTransformation(input, script, output, planLater(child))(hiveContext) :: Nil + case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => + ScriptTransformation(input, script, output, planLater(child), schema)(hiveContext) :: Nil case _ => Nil } } object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + case logical.InsertIntoTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil - case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) => + table, partition, planLater(child), overwrite, ifNotExists) :: Nil + case hive.InsertIntoHiveTable( + table: MetastoreRelation, partition, child, overwrite, ifNotExists) => execution.InsertIntoHiveTable( - table, partition, planLater(child), overwrite) :: Nil + table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil } } @@ -212,9 +220,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, options) => + case CreateTableUsing( + tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => ExecutedCommand( - CreateMetastoreDataSource(tableName, userSpecifiedSchema, provider, options)) :: Nil + CreateMetastoreDataSource( + tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + + case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) => + val cmd = + CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query) + ExecutedCommand(cmd) :: Nil case _ => Nil } @@ -228,8 +243,11 @@ private[hive] trait HiveStrategies { case t: MetastoreRelation => ExecutedCommand( DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil + case o: LogicalPlan => - ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil + val resultPlan = context.executePlan(o).executedPlan + ExecutedCommand(RunnableDescribeCommand( + resultPlan, describe.output, describe.isExtended)) :: Nil } case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index c368715f7c6f..e556c74ffb01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -34,6 +34,8 @@ import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DateUtils +import org.apache.spark.util.Utils /** * A trait for subclasses that handle table scans. @@ -75,7 +77,9 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + Class.forName( + relation.tableDesc.getSerdeClassName, true, sc.sessionState.getConf.getClassLoader) + .asInstanceOf[Class[Deserializer]], filterOpt = None) /** @@ -115,7 +119,7 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } deserializedHadoopRDD @@ -141,7 +145,46 @@ class HadoopTableReader( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[Row] = { - val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + + // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists + def verifyPartitionPath( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): + Map[HivePartition, Class[_ <: Deserializer]] = { + if (!sc.conf.verifyPartitionPath) { + partitionToDeserializer + } else { + var existPathSet = collection.mutable.Set[String]() + var pathPatternSet = collection.mutable.Set[String]() + partitionToDeserializer.filter { + case (partition, partDeserializer) => + def updateExistPathSetByPathPattern(pathPatternStr: String) { + val pathPattern = new Path(pathPatternStr) + val fs = pathPattern.getFileSystem(sc.hiveconf) + val matches = fs.globStatus(pathPattern) + matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) + } + // convert /demo/data/year/month/day to /demo/data/*/*/*/ + def getPathPatternByPath(parNum: Int, tempPath: Path): String = { + var path = tempPath + for (i <- (1 to parNum)) path = path.getParent + val tails = (1 to parNum).map(_ => "*").mkString("/", "/", "/") + path.toString + tails + } + + val partPath = HiveShim.getDataLocationPath(partition) + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); + var pathPatternStr = getPathPatternByPath(partNum, partPath) + if (!pathPatternSet.contains(pathPatternStr)) { + pathPatternSet += pathPatternStr + updateExistPathSetByPathPattern(pathPatternStr) + } + existPathSet.contains(partPath.toString) + } + } + } + + val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) + .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = HiveShim.getDataLocationPath(partition) val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) @@ -174,7 +217,7 @@ class HadoopTableReader( relation.partitionKeys.contains(attr) } - def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = relation.partitionKeys.indexOf(attr) row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) @@ -188,9 +231,13 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) + // get the table deserializer + val tableSerDe = tableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, tableDesc.getProperties) // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, tableSerDe) } }.toSeq @@ -247,7 +294,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { * instantiate a HadoopRDD. */ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { - FileInputFormat.setInputPaths(jobConf, path) + FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) if (tableDesc != null) { PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) @@ -260,25 +307,36 @@ private[hive] object HadoopTableReader extends HiveInspectors { * Transform all given raw `Writable`s into `Row`s. * * @param iterator Iterator of all `Writable`s to be transformed - * @param deserializer The `Deserializer` associated with the input `Writable` + * @param rawDeser The `Deserializer` associated with the input `Writable` * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled + * @param tableDeser Table Deserializer * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( iterator: Iterator[Writable], - deserializer: Deserializer, + rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow, + tableDeser: Deserializer): Iterator[Row] = { + + val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { + rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] + } else { + HiveShim.getConvertedOI( + rawDeser.getObjectInspector, + tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] + } - val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip - // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern - // matching and branching costs per row. + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => @@ -306,7 +364,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value)) + row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) @@ -315,9 +373,11 @@ private[hive] object HadoopTableReader extends HiveInspectors { } } + val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) + // Map each tuple to a row object iterator.map { value => - val raw = deserializer.deserialize(value) + val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index a547babcebff..76a1965f3cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.MetastoreRelation /** - * :: Experimental :: * Create table and insert the query result into it. * @param database the database name of the new relation * @param tableName the table name of the new relation @@ -38,7 +37,7 @@ import org.apache.spark.sql.hive.MetastoreRelation * @param desc the CreateTableDesc, which may contains serde, storage handler etc. */ -@Experimental +private[hive] case class CreateTableAsSelect( database: String, tableName: String, @@ -46,7 +45,7 @@ case class CreateTableAsSelect( allowExisting: Boolean, desc: Option[CreateTableDesc]) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { // Create Hive Table @@ -68,7 +67,7 @@ case class CreateTableAsSelect( new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index bfacc51ef57a..6fce69b58b85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -30,16 +30,14 @@ import org.apache.spark.sql.SQLContext /** * Implementation for "describe [extended] table". - * - * :: DeveloperApi :: */ -@DeveloperApi +private[hive] case class DescribeHiveTableCommand( table: MetastoreRelation, override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 781a2e9164c8..60a9bb630d0d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,22 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StringType -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { - override def output = + override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext) = + override def run(sqlContext: SQLContext): Seq[Row] = sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index b56175fe7637..0a5f19eee710 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -26,21 +26,20 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} /** - * :: DeveloperApi :: * The Hive table scan operator. Column and partition pruning are both handled. * * @param requestedAttributes Attributes to be fetched from the Hive table. * @param relation The Hive table be be scanned. * @param partitionPruningPred An optional partition pruning predicate for partitioned table. */ -@DeveloperApi +private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, @@ -130,11 +129,11 @@ case class HiveTableScan( } } - override def execute() = if (!relation.hiveQlTable.isPartitioned) { + override def execute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } - override def output = attributes + override def output: Seq[Attribute] = attributes } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 42bc8a0b6793..89995a91b1a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -32,29 +32,26 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, partition: Map[String, Option[String]], child: SparkPlan, - overwrite: Boolean) extends UnaryNode with HiveInspectors { + overwrite: Boolean, + ifNotExists: Boolean) extends UnaryNode with HiveInspectors { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -62,7 +59,7 @@ case class InsertIntoHiveTable( serializer } - def output = child.output + def output: Seq[Attribute] = child.output def saveAsHiveFile( rdd: RDD[Row], @@ -76,7 +73,6 @@ case class InsertIntoHiveTable( val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName assert(outputFileFormatClassName != null, "Output format class not set") conf.value.set("mapred.output.format.class", outputFileFormatClassName) - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( conf.value, @@ -204,42 +200,59 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + // scalastyle:off + // ifNotExists is only valid with static partition, refer to + // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries + // scalastyle:on + val oldPart = catalog.synchronized { + catalog.client.getPartition( + catalog.client.getTable(qualifiedTableName), partitionSpec, false) + } + if (oldPart == null || !ifNotExists) { + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. - sqlContext.invalidateCache(table) + sqlContext.cacheManager.invalidateCache(table) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 0c8f676e9c5c..3eddda3b28c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,35 +17,44 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} +import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.util.Properties -import org.apache.spark.annotation.DeveloperApi +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.AbstractSerDe +import org.apache.hadoop.hive.serde2.objectinspector._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.HiveContext - -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Transforms the input by forking and running the specified script. * * @param input the set of expression that should be passed to the script. * @param script the command that should be executed. * @param output the attributes that are produced by the script. */ -@DeveloperApi +private[hive] case class ScriptTransformation( input: Seq[Expression], script: String, output: Seq[Attribute], - child: SparkPlan)(@transient sc: HiveContext) + child: SparkPlan, + ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) extends UnaryNode { - override def otherCopyArgs = sc :: Nil + override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - def execute() = { + def execute(): RDD[Row] = { child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) @@ -53,28 +62,211 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val reader = new BufferedReader(new InputStreamReader(inputStream)) + + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + + val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { + var cacheRow: Row = null + var curLine: String = null + var eof: Boolean = false - // TODO: This should be exposed as an iterator instead of reading in all the data at once. - val outputLines = collection.mutable.ArrayBuffer[Row]() - val readerThread = new Thread("Transform OutputReader") { - override def run() { - var curLine = reader.readLine() - while (curLine != null) { - // TODO: Use SerDe - outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]]) + override def hasNext: Boolean = { + if (outputSerde == null) { + if (curLine == null) { + curLine = reader.readLine() + curLine != null + } else { + true + } + } else { + !eof + } + } + + def deserialize(): Row = { + if (cacheRow != null) return cacheRow + + val mutableRow = new SpecificMutableRow(output.map(_.dataType)) + try { + val dataInputStream = new DataInputStream(inputStream) + val writable = outputSerde.getSerializedClass().newInstance + writable.readFields(dataInputStream) + + val raw = outputSerde.deserialize(writable) + val dataList = outputSoi.getStructFieldsDataAsList(raw) + val fieldList = outputSoi.getAllStructFieldRefs() + + var i = 0 + dataList.foreach( element => { + if (element == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector) + } + i += 1 + }) + return mutableRow + } catch { + case e: EOFException => + eof = true + return null + } + } + + override def next(): Row = { + if (!hasNext) { + throw new NoSuchElementException + } + + if (outputSerde == null) { + val prevLine = curLine curLine = reader.readLine() + if (!ioschema.schemaLess) { + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) + .asInstanceOf[Array[Any]]) + } else { + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) + .asInstanceOf[Array[Any]]) + } + } else { + val ret = deserialize() + if (!eof) { + cacheRow = null + cacheRow = deserialize() + } + ret } } } - readerThread.start() + + val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) + val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) - iter - .map(outputProjection) - // TODO: Use SerDe - .map(_.mkString("", "\t", "\n").getBytes("utf-8")).foreach(outputStream.write) - outputStream.close() - readerThread.join() - outputLines.toIterator + + // Put the write(output to the pipeline) into a single thread + // and keep the collector as remain in the main thread. + // otherwise it will causes deadlock if the data size greater than + // the pipeline / buffer capacity. + new Thread(new Runnable() { + override def run(): Unit = { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } + }).start() + + iterator + } + } +} + +/** + * The wrapper class of Hive input and output schema properties + */ +private[hive] +case class HiveScriptIOSchema ( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: String, + outputSerdeClass: String, + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { + + val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n")) + + val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + + + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) + (serde, initInputSoi(serde, columns, columnTypes)) + } + + def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) + (serde, initOutputputSoi(serde)) + } + + def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + + val columns = attrs.map { + case aref: AttributeReference => aref.name + case e: NamedExpression => e.name + case _ => null + } + + val columnTypes = attrs.map { + case aref: AttributeReference => aref.dataType + case e: NamedExpression => e.dataType + case _ => null + } + + (columns, columnTypes) + } + + def initSerDe(serdeClassName: String, columns: Seq[String], + columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { + + val serde: AbstractSerDe = if (serdeClassName != "") { + val trimed_class = serdeClassName.split("'")(1) + Utils.classForName(trimed_class) + .newInstance.asInstanceOf[AbstractSerDe] + } else { + null + } + + if (serde != null) { + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + + var propsMap = serdeProps.map(kv => { + (kv._1.split("'")(1), kv._2.split("'")(1)) + }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + + val properties = new Properties() + properties.putAll(propsMap) + serde.initialize(null, properties) + } + + serde + } + + def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) + : ObjectInspector = { + + if (inputSerde != null) { + val fieldObjectInspectors = columnTypes.map(toInspector(_)) + ObjectInspectorFactory + .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .asInstanceOf[ObjectInspector] + } else { + null + } + } + + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { + if (outputSerde != null) { + outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] + } else { + null } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 91f9da35abee..a40a1e53117c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,51 +17,54 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** - * :: DeveloperApi :: * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. * * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. */ -@DeveloperApi +private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) Seq.empty[Row] } } /** - * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. */ -@DeveloperApi +private[hive] case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { - hiveContext.tryUncacheQuery(hiveContext.table(tableName)) + hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - // Other exceptions can be caused by users providing wrong parameters in OPTIONS + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => + // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an exception. - case e: Exception => log.warn(s"${e.getMessage}") + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") @@ -70,27 +73,27 @@ case class DropTable( } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddJar(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) - Seq.empty[Row] + Seq(Row(0)) } } -/** - * :: DeveloperApi :: - */ -@DeveloperApi +private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) @@ -98,16 +101,142 @@ case class AddFile(path: String) extends RunnableCommand { } } +private[hive] case class CreateMetastoreDataSource( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, - options: Map[String, String]) extends RunnableCommand { + options: Map[String, String], + allowExisting: Boolean, + managedIfNoPath: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext) = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - hiveContext.catalog.createDataSourceTable(tableName, userSpecifiedSchema, provider, options) + if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (allowExisting) { + return Seq.empty[Row] + } else { + throw new AnalysisException(s"Table $tableName already exists.") + } + } + + var isExternal = true + val optionsWithPath = + if (!options.contains("path") && managedIfNoPath) { + isExternal = false + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + } else { + options + } + + hiveContext.catalog.createDataSourceTable( + tableName, + userSpecifiedSchema, + provider, + optionsWithPath, + isExternal) + + Seq.empty[Row] + } +} + +private[hive] +case class CreateMetastoreDataSourceAsSelect( + tableName: String, + provider: String, + mode: SaveMode, + options: Map[String, String], + query: LogicalPlan) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + var createMetastoreTable = false + var isExternal = true + val optionsWithPath = + if (!options.contains("path")) { + isExternal = false + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + } else { + options + } + + var existingSchema = None: Option[StructType] + if (sqlContext.catalog.tableExists(Seq(tableName))) { + // Check if we need to throw an exception or just return. + mode match { + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists. " + + s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " + + s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" + + s"the existing data. " + + s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.") + case SaveMode.Ignore => + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty[Row] + case SaveMode.Append => + // Check if the specified data source match the data source of the existing table. + val resolved = + ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) + val createdRelation = LogicalRelation(resolved.relation) + EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + case l @ LogicalRelation(i: InsertableRelation) => + if (i != createdRelation.relation) { + val errorDescription = + s"Cannot append to table $tableName because the resolved relation does not " + + s"match the existing relation of $tableName. " + + s"You can use insertInto($tableName, false) to append this DataFrame to the " + + s"table $tableName and using its data source and options." + val errorMessage = + s""" + |$errorDescription + |== Relations == + |${sideBySide( + s"== Expected Relation ==" :: + l.toString :: Nil, + s"== Actual Relation ==" :: + createdRelation.toString :: Nil).mkString("\n")} + """.stripMargin + throw new AnalysisException(errorMessage) + } + existingSchema = Some(l.schema) + case o => + throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") + } + case SaveMode.Overwrite => + hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + // Need to create the table again. + createMetastoreTable = true + } + } else { + // The table does not exist. We need to create it in metastore. + createMetastoreTable = true + } + + val data = DataFrame(hiveContext, query) + val df = existingSchema match { + // If we are inserting into an existing table, just use the existing schema. + case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema) + case None => data + } + + // Create the relation based on the data of df. + val resolved = ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df) + + if (createMetastoreTable) { + // We will use the schema of resolved.relation as the schema of the table (instead of + // the schema of df). It is important since the nullability may be changed by the relation + // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). + hiveContext.catalog.createDataSourceTable( + tableName, + Some(resolved.relation.schema), + provider, + optionsWithPath, + isExternal) + } + + // Refresh the cache of the table in the catalog. + hiveContext.refreshTable(tableName) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 76d214037219..4b6f0ad75f54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -33,8 +33,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils.getContextOrSparkClassLoader +import org.apache.spark.sql.catalyst.analysis.MultiAlias +import org.apache.spark.sql.catalyst.errors.TreeNodeException /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -42,7 +45,7 @@ import scala.collection.JavaConversions._ private[hive] abstract class HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -63,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) + HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -75,7 +78,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type EvaluatedType = Any type UDFType = UDF - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -93,7 +96,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre udfType != null && udfType.deterministic() } - override def foldable = isUDFDeterministic && children.forall(_.foldable) + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) // Create parameter converters @transient @@ -107,7 +110,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached = new Array[AnyRef](children.length) + protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) // TODO: Finish input output types. override def eval(input: Row): Any = { @@ -117,17 +120,19 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } // Adapter from Catalyst ExpressionResult to Hive DeferredObject private[hive] class DeferredObjectAdapter(oi: ObjectInspector) extends DeferredObject with HiveInspectors { private var func: () => Any = _ - def set(func: () => Any) { + def set(func: () => Any): Unit = { this.func = func } - override def prepare(i: Int) = {} + override def prepare(i: Int): Unit = {} override def get(): AnyRef = wrap(func(), oi) } @@ -136,7 +141,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr type UDFType = GenericUDF type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true @transient lazy val function = funcWrapper.createFunction[UDFType]() @@ -155,7 +160,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr (udfType != null && udfType.deterministic()) } - override def foldable = + override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient @@ -179,7 +184,9 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr unwrap(function.evaluate(deferedObjects), returnInspector) } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } private[hive] case class HiveGenericUdaf( @@ -206,9 +213,11 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ @@ -237,10 +246,11 @@ private[hive] case class HiveUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } - def newInstance() = - new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) } /** @@ -256,7 +266,6 @@ private[hive] case class HiveUdaf( */ private[hive] case class HiveGenericUdtf( funcWrapper: HiveFunctionWrapper, - aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors { @@ -272,23 +281,8 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) - protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { - field => inspectorToDataType(field.getFieldObjectInspector) - } - - override protected def makeOutput() = { - // Use column names when given, otherwise _c1, _c2, ... _cn. - if (aliasNames.size == outputDataTypes.size) { - aliasNames.zip(outputDataTypes).map { - case (attrName, attrDataType) => - AttributeReference(attrName, attrDataType, nullable = true)() - } - } else { - outputDataTypes.zipWithIndex.map { - case (attrDataType, i) => - AttributeReference(s"_c$i", attrDataType, nullable = true)() - } - } + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) } override def eval(input: Row): TraversableOnce[Row] = { @@ -311,14 +305,16 @@ private[hive] case class HiveGenericUdtf( collected += unwrap(input, outputInspector).asInstanceOf[Row] } - def collectRows() = { + def collectRows(): Seq[Row] = { val toCollect = collected collected = new ArrayBuffer[Row] toCollect } } - override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } } private[hive] case class HiveUdafFunction( @@ -347,9 +343,8 @@ private[hive] case class HiveUdafFunction( private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - // Cast required to avoid type inference selecting a deprecated Hive API. private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + function.getNewAggregationBuffer override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index aae175e426ad..8398da268174 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import java.io.IOException import java.text.NumberFormat import java.util.Date @@ -30,6 +29,7 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ +import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row @@ -117,19 +117,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { @@ -212,11 +200,16 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => val string = if (rawVal == null) null else String.valueOf(rawVal) - s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" - } - .mkString - - def newWriter = { + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string) + } + s"/$col=$colString" + }.mkString + + def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( fileSinkConf.getDirName + dynamicPartPath, fileSinkConf.getTableInfo, @@ -240,6 +233,6 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( Reporter.NULL) } - writers.getOrElseUpdate(dynamicPartPath, newWriter) + writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java index 8b29fa7d1a8f..4b23fbf6e736 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java @@ -15,4 +15,4 @@ * limitations under the License. */ -package org.apache.spark.sql.hive; \ No newline at end of file +package org.apache.spark.sql.hive; diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala index a6c8ed4f7e86..db074361ef03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala @@ -17,4 +17,14 @@ package org.apache.spark.sql +/** + * Support for running Spark SQL queries using functionality from Apache Hive (does not require an + * existing Hive installation). Supported Hive features include: + * - Using HiveQL to express queries. + * - Reading metadata from the Hive Metastore using HiveSerDes. + * - Hive UDFs, UDAs, UDTs + * + * Users that would like access to this functionality should create a + * [[hive.HiveContext HiveContext]] instead of a [[SQLContext]]. + */ package object hive diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala deleted file mode 100644 index 2a16c9d1a27c..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.parquet - -import java.util.Properties - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category -import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector -import org.apache.hadoop.io.Writable - -/** - * A placeholder that allows Spark SQL users to create metastore tables that are stored as - * parquet files. It is only intended to pass the checks that the serde is valid and exists - * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan - * when "spark.sql.hive.convertMetastoreParquet" is set to true. - */ -@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + - "placeholder in the Hive MetaStore", "1.2.0") -class FakeParquetSerDe extends SerDe { - override def getObjectInspector: ObjectInspector = new ObjectInspector { - override def getCategory: Category = Category.PRIMITIVE - - override def getTypeName: String = "string" - } - - override def deserialize(p1: Writable): AnyRef = throwError - - override def initialize(p1: Configuration, p2: Properties): Unit = {} - - override def getSerializedClass: Class[_ <: Writable] = throwError - - override def getSerDeStats: SerDeStats = throwError - - override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError - - private def throwError = - sys.error( - "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala similarity index 89% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 47431cef03e1..9f17bca083d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,26 +20,25 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import scala.collection.mutable -import scala.language.implicitConversions - import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +import scala.collection.mutable +import scala.language.implicitConversions /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -68,22 +67,21 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { System.clearProperty("spark.hostPort") CommandProcessorFactory.clean(hiveconf) - lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + hiveconf.set("hive.plan.serialization.format", "javaXML") + + lazy val warehousePath = Utils.createTempDir() + lazy val metastorePath = Utils.createTempDir() /** Sets up the system initially or after a RESET command */ protected def configure(): Unit = { + warehousePath.delete() + metastorePath.delete() setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") - setConf("hive.metastore.warehouse.dir", warehousePath) - Utils.registerShutdownDeleteDir(new File(warehousePath)) - Utils.registerShutdownDeleteDir(new File(metastorePath)) + setConf("hive.metastore.warehouse.dir", warehousePath.toString) } - val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") - testTempDir.delete() - testTempDir.mkdir() - Utils.registerShutdownDeleteDir(testTempDir) + val testTempDir = Utils.createTempDir() // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) @@ -99,12 +97,18 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) + + override protected[sql] def createSession(): SQLSession = { + new this.SQLSession() + } - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + protected[hive] class SQLSession extends super.SQLSession { + /** Fewer partitions to speed up testing. */ + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + } } /** @@ -150,16 +154,22 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(hql) - def hiveExec() = runSqlHive(hql) - override def toString = hql + "\n" + super.toString + val vs = new VariableSubstitution() + + // we should substitute variables in hql to pass the text to parseSql() as a parameter. + // Hive parser need substituted text. HiveContext.sql() does this but return a DataFrame, + // while we need a logicalPlan so we cannot reuse that. + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) { + def hiveExec(): Seq[String] = runSqlHive(hql) + override def toString: String = hql + "\n" + super.toString } /** * Override QueryExecution with special debug workflow. */ - abstract class QueryExecution extends super.QueryExecution { + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { override lazy val analyzed = { val describedTables = logical match { case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil @@ -175,14 +185,16 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. - analyzer(logical) + analyzer.execute(logical) } } case class TestTable(name: String, commands: (()=>Unit)*) protected[hive] implicit class SqlCmd(sql: String) { - def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit + def cmd: () => Unit = { + () => new HiveQLQueryExecution(sql).stringResult(): Unit + } } /** @@ -190,10 +202,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * demand when a query are run against it. */ lazy val testTables = new mutable.HashMap[String, TestTable]() - def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + def registerTestTable(testTable: TestTable): Unit = { + testTables += (testTable.name -> testTable) + } // The test tables that are defined in the Hive QTestUtil. // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql val hiveQTestUtilTables = Seq( TestTable("src", "CREATE TABLE src (key INT, value STRING)".cmd, @@ -221,11 +237,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } }), TestTable("src_thrift", () => { - import org.apache.thrift.protocol.TBinaryProtocol - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.mapred.SequenceFileInputFormat - import org.apache.hadoop.mapred.SequenceFileOutputFormat + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} + import org.apache.thrift.protocol.TBinaryProtocol val srcThrift = new Table("default", "src_thrift") srcThrift.setFields(Nil) @@ -247,12 +262,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |WITH SERDEPROPERTIES ('field.delim'='\\t') """.stripMargin.cmd, "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), - TestTable("sales", - s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) - |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' @@ -395,7 +404,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - clearCache() + cacheManager.clearCache() loadedTables.clear() catalog.cachedDataSourceTables.invalidateAll() catalog.client.getAllTables("default").foreach { t => diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java new file mode 100644 index 000000000000..53ddecf57958 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.SaveMode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.QueryTest$; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaMetastoreDataSourcesSuite { + private transient JavaSparkContext sc; + private transient HiveContext sqlContext; + + String originalDefaultSource; + File path; + Path hiveManagedPath; + FileSystem fs; + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestHive$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); + if (fs.exists(hiveManagedPath)){ + fs.delete(hiveManagedPath, true); + } + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } + + @Test + public void saveExternalTableAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + df.collectAsList()); + } + + @Test + public void saveExternalTableWithSchemaAndQueryIt() { + Map options = new HashMap(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + List fields = new ArrayList(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); + + checkAnswer( + loadedDF, + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + } + + @Test + public void saveTableAndQueryIt() { + Map options = new HashMap(); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + } +} diff --git a/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 new file mode 100644 index 000000000000..f6ba75da254c --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 @@ -0,0 +1,3 @@ +5 +5 +5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 new file mode 100644 index 000000000000..ca7b591095e2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 @@ -0,0 +1,4 @@ +val_4 +val_5 +val_5 +val_5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 new file mode 100644 index 000000000000..b8626c4cff28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 @@ -0,0 +1 @@ +4 diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d new file mode 100644 index 000000000000..98da82fa8938 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d @@ -0,0 +1 @@ +1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e new file mode 100644 index 000000000000..27ba77ddaf61 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 1-0-bde89be08a12361073ff658fef768b7e @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 new file mode 100644 index 000000000000..27ba77ddaf61 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/create table as with db name within backticks-0-a253b1ed35dbf503d1b8902dacbe23ac similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/create table as with db name within backticks-0-a253b1ed35dbf503d1b8902dacbe23ac diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/create table as with db name within backticks-1-61dc640dfeaff471f3d2b730f9cbf959 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b rename to sql/hive/src/test/resources/golden/create table as with db name within backticks-1-61dc640dfeaff471f3d2b730f9cbf959 diff --git a/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 b/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/create table as with db name within backticks-2-ce780d068b8d24786e639e361101a0c7 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/create table as with db name within backticks-3-afd6e46b6a289c3c24a8eec75a94043c similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb rename to sql/hive/src/test/resources/golden/create table as with db name within backticks-3-afd6e46b6a289c3c24a8eec75a94043c diff --git a/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a rename to sql/hive/src/test/resources/golden/date_1-1-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-10-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-11-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 rename to sql/hive/src/test/resources/golden/date_1-12-480c5f024a28232b7857be327c992509 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 rename to sql/hive/src/test/resources/golden/date_1-13-4c0ed7fcb75770d8790575b586bf14f4 diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea rename to sql/hive/src/test/resources/golden/date_1-14-44fc74c1993062c0a9522199ff27fea diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b rename to sql/hive/src/test/resources/golden/date_1-15-4855a66124b16d1d0d003235995ac06b diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b rename to sql/hive/src/test/resources/golden/date_1-16-8bc190dba0f641840b5e1e198a14c55b diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d rename to sql/hive/src/test/resources/golden/date_1-17-23edf29bf7376c70d5ecf12720f4b1eb diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d rename to sql/hive/src/test/resources/golden/date_1-2-4ebe3571c13a8b0c03096fbd972b7f1b diff --git a/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 b/sql/hive/src/test/resources/golden/date_1-3-26b5c291400dfde455b3c1b878b71d0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-4-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-5-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 b/sql/hive/src/test/resources/golden/date_1-6-559d01fb0b42c42f0c4927fa0f9deac4 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b rename to sql/hive/src/test/resources/golden/date_1-7-df16364a220ff96a6ea1cd478cbc1d0b diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 similarity index 100% rename from sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 rename to sql/hive/src/test/resources/golden/date_1-8-d964bec7e5632091ab5cb6f6786dbbf9 diff --git a/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 b/sql/hive/src/test/resources/golden/date_1-9-8306558e0eabe936ac33dabaaa17fea4 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 b/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 b/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b new file mode 100644 index 000000000000..518a70918b2c --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b @@ -0,0 +1 @@ +name string diff --git a/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce new file mode 100644 index 000000000000..33398360345d --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce @@ -0,0 +1 @@ +邵铮 diff --git a/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 000000000000..01e79c32a8c9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 new file mode 100644 index 000000000000..0c7520f2090d --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 @@ -0,0 +1,3 @@ +86 val_86 +238 val_238 +311 val_311 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 000000000000..01e79c32a8c9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-0-d5edc0daa94b33915df794df3b710774 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-1-9eb9372f4855928fae16f5fa554b3a62 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-10-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-11-8854d6001200fc11529b2e2da755e5a2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-12-71ff68fda0aa7a36cb50d8fab0d70d25 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-13-7e4e7d7003fc6ef17bc19c3461ad899 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-14-ec2cef3d37146c450c60202a572f5cab new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-15-a3b2e230efde74e970ae8a3b55f383fc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-2-8396c17a66e3d9a374d4361873b9bfe3 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-3-3876bb356dd8af7e78d061093d555457 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-4-528e23afb272c2e69004c86ddaa70ee new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-5-de5d56456c28d63775554e56355911d2 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 000000000000..185a91c110d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-6-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-7-92f6af82704504968de078c133f222f8 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-8-316cad7c63ddd4fb043be2affa5b0a67 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 new file mode 100644 index 000000000000..185a91c110d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1_overwrite_partitions-9-3efdc331b3b4bdac3e60c757600fff53 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +97 val_97 +97 val_97 +96 val_96 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea index 25ce912507d5..a1963ba81e0d 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea +++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 index 25ce912507d5..a1963ba81e0d 100644 --- a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 +++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 @@ -1,4 +1,2 @@ Hank 2 -Hank 2 -Joe 2 Joe 2 diff --git a/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-0-b12e5c70d6d29757471b900b6160fa8a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-1-593999fae618b6b38322bc9ae4e0c027 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 b/sql/hive/src/test/resources/golden/merge4-10-692a197bd688b48f762e72978f54aa32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 000000000000..5d2cddc42f27 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-11-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1500 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 b/sql/hive/src/test/resources/golden/merge4-12-62541540a18d68a3cb8497a741061d11 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 b/sql/hive/src/test/resources/golden/merge4-13-ed1103f06609365b40e78d13c654cc71 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 new file mode 100644 index 000000000000..30becc42d7b5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-14-ba5dbcd0527b8ddab284bc322255bfc7 @@ -0,0 +1,3 @@ +ds=2010-08-15/hr=11 +ds=2010-08-15/hr=12 +ds=2010-08-15/hr=file, diff --git a/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a new file mode 100644 index 000000000000..4c867a5deff0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-15-68f50dc2ad6ff803a372bdd88dd8e19a @@ -0,0 +1 @@ +1 1 2010-08-15 file, diff --git a/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-2-43d53504df013e6b35f81811138a167a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-3-a4fb8359a2179ec70777aad6366071b7 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-4-16367c381d4b189b3640c92511244bfe @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d b/sql/hive/src/test/resources/golden/merge4-5-3d24d877366c42030f6d9a596665720d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af b/sql/hive/src/test/resources/golden/merge4-6-b3a76420183795720ab3a384046e5af new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d b/sql/hive/src/test/resources/golden/merge4-7-631a45828eae3f5f562d992efe4cd56d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b new file mode 100644 index 000000000000..aa972caa5665 --- /dev/null +++ b/sql/hive/src/test/resources/golden/merge4-8-f407e661307b23a5d52a08a3e7af19b @@ -0,0 +1,1000 @@ +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 11 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +0 val_0 2010-08-15 12 +2 val_2 2010-08-15 11 +2 val_2 2010-08-15 12 +4 val_4 2010-08-15 11 +4 val_4 2010-08-15 12 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 11 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +5 val_5 2010-08-15 12 +8 val_8 2010-08-15 11 +8 val_8 2010-08-15 12 +9 val_9 2010-08-15 11 +9 val_9 2010-08-15 12 +10 val_10 2010-08-15 11 +10 val_10 2010-08-15 12 +11 val_11 2010-08-15 11 +11 val_11 2010-08-15 12 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 11 +12 val_12 2010-08-15 12 +12 val_12 2010-08-15 12 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 11 +15 val_15 2010-08-15 12 +15 val_15 2010-08-15 12 +17 val_17 2010-08-15 11 +17 val_17 2010-08-15 12 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 11 +18 val_18 2010-08-15 12 +18 val_18 2010-08-15 12 +19 val_19 2010-08-15 11 +19 val_19 2010-08-15 12 +20 val_20 2010-08-15 11 +20 val_20 2010-08-15 12 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 11 +24 val_24 2010-08-15 12 +24 val_24 2010-08-15 12 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 11 +26 val_26 2010-08-15 12 +26 val_26 2010-08-15 12 +27 val_27 2010-08-15 11 +27 val_27 2010-08-15 12 +28 val_28 2010-08-15 11 +28 val_28 2010-08-15 12 +30 val_30 2010-08-15 11 +30 val_30 2010-08-15 12 +33 val_33 2010-08-15 11 +33 val_33 2010-08-15 12 +34 val_34 2010-08-15 11 +34 val_34 2010-08-15 12 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 11 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +35 val_35 2010-08-15 12 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 11 +37 val_37 2010-08-15 12 +37 val_37 2010-08-15 12 +41 val_41 2010-08-15 11 +41 val_41 2010-08-15 12 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 11 +42 val_42 2010-08-15 12 +42 val_42 2010-08-15 12 +43 val_43 2010-08-15 11 +43 val_43 2010-08-15 12 +44 val_44 2010-08-15 11 +44 val_44 2010-08-15 12 +47 val_47 2010-08-15 11 +47 val_47 2010-08-15 12 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 11 +51 val_51 2010-08-15 12 +51 val_51 2010-08-15 12 +53 val_53 2010-08-15 11 +53 val_53 2010-08-15 12 +54 val_54 2010-08-15 11 +54 val_54 2010-08-15 12 +57 val_57 2010-08-15 11 +57 val_57 2010-08-15 12 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 11 +58 val_58 2010-08-15 12 +58 val_58 2010-08-15 12 +64 val_64 2010-08-15 11 +64 val_64 2010-08-15 12 +65 val_65 2010-08-15 11 +65 val_65 2010-08-15 12 +66 val_66 2010-08-15 11 +66 val_66 2010-08-15 12 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 11 +67 val_67 2010-08-15 12 +67 val_67 2010-08-15 12 +69 val_69 2010-08-15 11 +69 val_69 2010-08-15 12 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 11 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +70 val_70 2010-08-15 12 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 11 +72 val_72 2010-08-15 12 +72 val_72 2010-08-15 12 +74 val_74 2010-08-15 11 +74 val_74 2010-08-15 12 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 11 +76 val_76 2010-08-15 12 +76 val_76 2010-08-15 12 +77 val_77 2010-08-15 11 +77 val_77 2010-08-15 12 +78 val_78 2010-08-15 11 +78 val_78 2010-08-15 12 +80 val_80 2010-08-15 11 +80 val_80 2010-08-15 12 +82 val_82 2010-08-15 11 +82 val_82 2010-08-15 12 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 11 +83 val_83 2010-08-15 12 +83 val_83 2010-08-15 12 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 11 +84 val_84 2010-08-15 12 +84 val_84 2010-08-15 12 +85 val_85 2010-08-15 11 +85 val_85 2010-08-15 12 +86 val_86 2010-08-15 11 +86 val_86 2010-08-15 12 +87 val_87 2010-08-15 11 +87 val_87 2010-08-15 12 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 11 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +90 val_90 2010-08-15 12 +92 val_92 2010-08-15 11 +92 val_92 2010-08-15 12 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 11 +95 val_95 2010-08-15 12 +95 val_95 2010-08-15 12 +96 val_96 2010-08-15 11 +96 val_96 2010-08-15 12 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 11 +97 val_97 2010-08-15 12 +97 val_97 2010-08-15 12 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 11 +98 val_98 2010-08-15 12 +98 val_98 2010-08-15 12 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 11 +100 val_100 2010-08-15 12 +100 val_100 2010-08-15 12 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 11 +103 val_103 2010-08-15 12 +103 val_103 2010-08-15 12 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 11 +104 val_104 2010-08-15 12 +104 val_104 2010-08-15 12 +105 val_105 2010-08-15 11 +105 val_105 2010-08-15 12 +111 val_111 2010-08-15 11 +111 val_111 2010-08-15 12 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 11 +113 val_113 2010-08-15 12 +113 val_113 2010-08-15 12 +114 val_114 2010-08-15 11 +114 val_114 2010-08-15 12 +116 val_116 2010-08-15 11 +116 val_116 2010-08-15 12 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 11 +118 val_118 2010-08-15 12 +118 val_118 2010-08-15 12 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 11 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +119 val_119 2010-08-15 12 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 11 +120 val_120 2010-08-15 12 +120 val_120 2010-08-15 12 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 11 +125 val_125 2010-08-15 12 +125 val_125 2010-08-15 12 +126 val_126 2010-08-15 11 +126 val_126 2010-08-15 12 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 11 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +128 val_128 2010-08-15 12 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 11 +129 val_129 2010-08-15 12 +129 val_129 2010-08-15 12 +131 val_131 2010-08-15 11 +131 val_131 2010-08-15 12 +133 val_133 2010-08-15 11 +133 val_133 2010-08-15 12 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 11 +134 val_134 2010-08-15 12 +134 val_134 2010-08-15 12 +136 val_136 2010-08-15 11 +136 val_136 2010-08-15 12 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 11 +137 val_137 2010-08-15 12 +137 val_137 2010-08-15 12 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 11 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +138 val_138 2010-08-15 12 +143 val_143 2010-08-15 11 +143 val_143 2010-08-15 12 +145 val_145 2010-08-15 11 +145 val_145 2010-08-15 12 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 11 +146 val_146 2010-08-15 12 +146 val_146 2010-08-15 12 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 11 +149 val_149 2010-08-15 12 +149 val_149 2010-08-15 12 +150 val_150 2010-08-15 11 +150 val_150 2010-08-15 12 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 11 +152 val_152 2010-08-15 12 +152 val_152 2010-08-15 12 +153 val_153 2010-08-15 11 +153 val_153 2010-08-15 12 +155 val_155 2010-08-15 11 +155 val_155 2010-08-15 12 +156 val_156 2010-08-15 11 +156 val_156 2010-08-15 12 +157 val_157 2010-08-15 11 +157 val_157 2010-08-15 12 +158 val_158 2010-08-15 11 +158 val_158 2010-08-15 12 +160 val_160 2010-08-15 11 +160 val_160 2010-08-15 12 +162 val_162 2010-08-15 11 +162 val_162 2010-08-15 12 +163 val_163 2010-08-15 11 +163 val_163 2010-08-15 12 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 11 +164 val_164 2010-08-15 12 +164 val_164 2010-08-15 12 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 11 +165 val_165 2010-08-15 12 +165 val_165 2010-08-15 12 +166 val_166 2010-08-15 11 +166 val_166 2010-08-15 12 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 11 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +167 val_167 2010-08-15 12 +168 val_168 2010-08-15 11 +168 val_168 2010-08-15 12 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 11 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +169 val_169 2010-08-15 12 +170 val_170 2010-08-15 11 +170 val_170 2010-08-15 12 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 11 +172 val_172 2010-08-15 12 +172 val_172 2010-08-15 12 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 11 +174 val_174 2010-08-15 12 +174 val_174 2010-08-15 12 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 11 +175 val_175 2010-08-15 12 +175 val_175 2010-08-15 12 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 11 +176 val_176 2010-08-15 12 +176 val_176 2010-08-15 12 +177 val_177 2010-08-15 11 +177 val_177 2010-08-15 12 +178 val_178 2010-08-15 11 +178 val_178 2010-08-15 12 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 11 +179 val_179 2010-08-15 12 +179 val_179 2010-08-15 12 +180 val_180 2010-08-15 11 +180 val_180 2010-08-15 12 +181 val_181 2010-08-15 11 +181 val_181 2010-08-15 12 +183 val_183 2010-08-15 11 +183 val_183 2010-08-15 12 +186 val_186 2010-08-15 11 +186 val_186 2010-08-15 12 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 11 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +187 val_187 2010-08-15 12 +189 val_189 2010-08-15 11 +189 val_189 2010-08-15 12 +190 val_190 2010-08-15 11 +190 val_190 2010-08-15 12 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 11 +191 val_191 2010-08-15 12 +191 val_191 2010-08-15 12 +192 val_192 2010-08-15 11 +192 val_192 2010-08-15 12 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 11 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +193 val_193 2010-08-15 12 +194 val_194 2010-08-15 11 +194 val_194 2010-08-15 12 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 11 +195 val_195 2010-08-15 12 +195 val_195 2010-08-15 12 +196 val_196 2010-08-15 11 +196 val_196 2010-08-15 12 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 11 +197 val_197 2010-08-15 12 +197 val_197 2010-08-15 12 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 11 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +199 val_199 2010-08-15 12 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 11 +200 val_200 2010-08-15 12 +200 val_200 2010-08-15 12 +201 val_201 2010-08-15 11 +201 val_201 2010-08-15 12 +202 val_202 2010-08-15 11 +202 val_202 2010-08-15 12 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 11 +203 val_203 2010-08-15 12 +203 val_203 2010-08-15 12 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 11 +205 val_205 2010-08-15 12 +205 val_205 2010-08-15 12 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 11 +207 val_207 2010-08-15 12 +207 val_207 2010-08-15 12 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 11 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +208 val_208 2010-08-15 12 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 11 +209 val_209 2010-08-15 12 +209 val_209 2010-08-15 12 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 11 +213 val_213 2010-08-15 12 +213 val_213 2010-08-15 12 +214 val_214 2010-08-15 11 +214 val_214 2010-08-15 12 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 11 +216 val_216 2010-08-15 12 +216 val_216 2010-08-15 12 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 11 +217 val_217 2010-08-15 12 +217 val_217 2010-08-15 12 +218 val_218 2010-08-15 11 +218 val_218 2010-08-15 12 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 11 +219 val_219 2010-08-15 12 +219 val_219 2010-08-15 12 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 11 +221 val_221 2010-08-15 12 +221 val_221 2010-08-15 12 +222 val_222 2010-08-15 11 +222 val_222 2010-08-15 12 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 11 +223 val_223 2010-08-15 12 +223 val_223 2010-08-15 12 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 11 +224 val_224 2010-08-15 12 +224 val_224 2010-08-15 12 +226 val_226 2010-08-15 11 +226 val_226 2010-08-15 12 +228 val_228 2010-08-15 11 +228 val_228 2010-08-15 12 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 11 +229 val_229 2010-08-15 12 +229 val_229 2010-08-15 12 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 11 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +230 val_230 2010-08-15 12 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 11 +233 val_233 2010-08-15 12 +233 val_233 2010-08-15 12 +235 val_235 2010-08-15 11 +235 val_235 2010-08-15 12 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 11 +237 val_237 2010-08-15 12 +237 val_237 2010-08-15 12 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 11 +238 val_238 2010-08-15 12 +238 val_238 2010-08-15 12 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 11 +239 val_239 2010-08-15 12 +239 val_239 2010-08-15 12 +241 val_241 2010-08-15 11 +241 val_241 2010-08-15 12 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 11 +242 val_242 2010-08-15 12 +242 val_242 2010-08-15 12 +244 val_244 2010-08-15 11 +244 val_244 2010-08-15 12 +247 val_247 2010-08-15 11 +247 val_247 2010-08-15 12 +248 val_248 2010-08-15 11 +248 val_248 2010-08-15 12 +249 val_249 2010-08-15 11 +249 val_249 2010-08-15 12 +252 val_252 2010-08-15 11 +252 val_252 2010-08-15 12 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 11 +255 val_255 2010-08-15 12 +255 val_255 2010-08-15 12 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 11 +256 val_256 2010-08-15 12 +256 val_256 2010-08-15 12 +257 val_257 2010-08-15 11 +257 val_257 2010-08-15 12 +258 val_258 2010-08-15 11 +258 val_258 2010-08-15 12 +260 val_260 2010-08-15 11 +260 val_260 2010-08-15 12 +262 val_262 2010-08-15 11 +262 val_262 2010-08-15 12 +263 val_263 2010-08-15 11 +263 val_263 2010-08-15 12 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 11 +265 val_265 2010-08-15 12 +265 val_265 2010-08-15 12 +266 val_266 2010-08-15 11 +266 val_266 2010-08-15 12 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 11 +272 val_272 2010-08-15 12 +272 val_272 2010-08-15 12 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 11 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +273 val_273 2010-08-15 12 +274 val_274 2010-08-15 11 +274 val_274 2010-08-15 12 +275 val_275 2010-08-15 11 +275 val_275 2010-08-15 12 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 11 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +277 val_277 2010-08-15 12 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 11 +278 val_278 2010-08-15 12 +278 val_278 2010-08-15 12 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 11 +280 val_280 2010-08-15 12 +280 val_280 2010-08-15 12 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 11 +281 val_281 2010-08-15 12 +281 val_281 2010-08-15 12 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 11 +282 val_282 2010-08-15 12 +282 val_282 2010-08-15 12 +283 val_283 2010-08-15 11 +283 val_283 2010-08-15 12 +284 val_284 2010-08-15 11 +284 val_284 2010-08-15 12 +285 val_285 2010-08-15 11 +285 val_285 2010-08-15 12 +286 val_286 2010-08-15 11 +286 val_286 2010-08-15 12 +287 val_287 2010-08-15 11 +287 val_287 2010-08-15 12 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 11 +288 val_288 2010-08-15 12 +288 val_288 2010-08-15 12 +289 val_289 2010-08-15 11 +289 val_289 2010-08-15 12 +291 val_291 2010-08-15 11 +291 val_291 2010-08-15 12 +292 val_292 2010-08-15 11 +292 val_292 2010-08-15 12 +296 val_296 2010-08-15 11 +296 val_296 2010-08-15 12 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 11 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +298 val_298 2010-08-15 12 +302 val_302 2010-08-15 11 +302 val_302 2010-08-15 12 +305 val_305 2010-08-15 11 +305 val_305 2010-08-15 12 +306 val_306 2010-08-15 11 +306 val_306 2010-08-15 12 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 11 +307 val_307 2010-08-15 12 +307 val_307 2010-08-15 12 +308 val_308 2010-08-15 11 +308 val_308 2010-08-15 12 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 11 +309 val_309 2010-08-15 12 +309 val_309 2010-08-15 12 +310 val_310 2010-08-15 11 +310 val_310 2010-08-15 12 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 11 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +311 val_311 2010-08-15 12 +315 val_315 2010-08-15 11 +315 val_315 2010-08-15 12 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 11 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +316 val_316 2010-08-15 12 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 11 +317 val_317 2010-08-15 12 +317 val_317 2010-08-15 12 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 11 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +318 val_318 2010-08-15 12 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 11 +321 val_321 2010-08-15 12 +321 val_321 2010-08-15 12 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 11 +322 val_322 2010-08-15 12 +322 val_322 2010-08-15 12 +323 val_323 2010-08-15 11 +323 val_323 2010-08-15 12 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 11 +325 val_325 2010-08-15 12 +325 val_325 2010-08-15 12 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 11 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +327 val_327 2010-08-15 12 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 11 +331 val_331 2010-08-15 12 +331 val_331 2010-08-15 12 +332 val_332 2010-08-15 11 +332 val_332 2010-08-15 12 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 11 +333 val_333 2010-08-15 12 +333 val_333 2010-08-15 12 +335 val_335 2010-08-15 11 +335 val_335 2010-08-15 12 +336 val_336 2010-08-15 11 +336 val_336 2010-08-15 12 +338 val_338 2010-08-15 11 +338 val_338 2010-08-15 12 +339 val_339 2010-08-15 11 +339 val_339 2010-08-15 12 +341 val_341 2010-08-15 11 +341 val_341 2010-08-15 12 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 11 +342 val_342 2010-08-15 12 +342 val_342 2010-08-15 12 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 11 +344 val_344 2010-08-15 12 +344 val_344 2010-08-15 12 +345 val_345 2010-08-15 11 +345 val_345 2010-08-15 12 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 11 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +348 val_348 2010-08-15 12 +351 val_351 2010-08-15 11 +351 val_351 2010-08-15 12 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 11 +353 val_353 2010-08-15 12 +353 val_353 2010-08-15 12 +356 val_356 2010-08-15 11 +356 val_356 2010-08-15 12 +360 val_360 2010-08-15 11 +360 val_360 2010-08-15 12 +362 val_362 2010-08-15 11 +362 val_362 2010-08-15 12 +364 val_364 2010-08-15 11 +364 val_364 2010-08-15 12 +365 val_365 2010-08-15 11 +365 val_365 2010-08-15 12 +366 val_366 2010-08-15 11 +366 val_366 2010-08-15 12 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 11 +367 val_367 2010-08-15 12 +367 val_367 2010-08-15 12 +368 val_368 2010-08-15 11 +368 val_368 2010-08-15 12 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 11 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +369 val_369 2010-08-15 12 +373 val_373 2010-08-15 11 +373 val_373 2010-08-15 12 +374 val_374 2010-08-15 11 +374 val_374 2010-08-15 12 +375 val_375 2010-08-15 11 +375 val_375 2010-08-15 12 +377 val_377 2010-08-15 11 +377 val_377 2010-08-15 12 +378 val_378 2010-08-15 11 +378 val_378 2010-08-15 12 +379 val_379 2010-08-15 11 +379 val_379 2010-08-15 12 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 11 +382 val_382 2010-08-15 12 +382 val_382 2010-08-15 12 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 11 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +384 val_384 2010-08-15 12 +386 val_386 2010-08-15 11 +386 val_386 2010-08-15 12 +389 val_389 2010-08-15 11 +389 val_389 2010-08-15 12 +392 val_392 2010-08-15 11 +392 val_392 2010-08-15 12 +393 val_393 2010-08-15 11 +393 val_393 2010-08-15 12 +394 val_394 2010-08-15 11 +394 val_394 2010-08-15 12 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 11 +395 val_395 2010-08-15 12 +395 val_395 2010-08-15 12 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 11 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +396 val_396 2010-08-15 12 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 11 +397 val_397 2010-08-15 12 +397 val_397 2010-08-15 12 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 11 +399 val_399 2010-08-15 12 +399 val_399 2010-08-15 12 +400 val_400 2010-08-15 11 +400 val_400 2010-08-15 12 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 11 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +401 val_401 2010-08-15 12 +402 val_402 2010-08-15 11 +402 val_402 2010-08-15 12 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 11 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +403 val_403 2010-08-15 12 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 11 +404 val_404 2010-08-15 12 +404 val_404 2010-08-15 12 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 11 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +406 val_406 2010-08-15 12 +407 val_407 2010-08-15 11 +407 val_407 2010-08-15 12 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 11 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +409 val_409 2010-08-15 12 +411 val_411 2010-08-15 11 +411 val_411 2010-08-15 12 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 11 +413 val_413 2010-08-15 12 +413 val_413 2010-08-15 12 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 11 +414 val_414 2010-08-15 12 +414 val_414 2010-08-15 12 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 11 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +417 val_417 2010-08-15 12 +418 val_418 2010-08-15 11 +418 val_418 2010-08-15 12 +419 val_419 2010-08-15 11 +419 val_419 2010-08-15 12 +421 val_421 2010-08-15 11 +421 val_421 2010-08-15 12 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 11 +424 val_424 2010-08-15 12 +424 val_424 2010-08-15 12 +427 val_427 2010-08-15 11 +427 val_427 2010-08-15 12 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 11 +429 val_429 2010-08-15 12 +429 val_429 2010-08-15 12 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 11 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +430 val_430 2010-08-15 12 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 11 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +431 val_431 2010-08-15 12 +432 val_432 2010-08-15 11 +432 val_432 2010-08-15 12 +435 val_435 2010-08-15 11 +435 val_435 2010-08-15 12 +436 val_436 2010-08-15 11 +436 val_436 2010-08-15 12 +437 val_437 2010-08-15 11 +437 val_437 2010-08-15 12 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 11 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +438 val_438 2010-08-15 12 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 11 +439 val_439 2010-08-15 12 +439 val_439 2010-08-15 12 +443 val_443 2010-08-15 11 +443 val_443 2010-08-15 12 +444 val_444 2010-08-15 11 +444 val_444 2010-08-15 12 +446 val_446 2010-08-15 11 +446 val_446 2010-08-15 12 +448 val_448 2010-08-15 11 +448 val_448 2010-08-15 12 +449 val_449 2010-08-15 11 +449 val_449 2010-08-15 12 +452 val_452 2010-08-15 11 +452 val_452 2010-08-15 12 +453 val_453 2010-08-15 11 +453 val_453 2010-08-15 12 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 11 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +454 val_454 2010-08-15 12 +455 val_455 2010-08-15 11 +455 val_455 2010-08-15 12 +457 val_457 2010-08-15 11 +457 val_457 2010-08-15 12 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 11 +458 val_458 2010-08-15 12 +458 val_458 2010-08-15 12 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 11 +459 val_459 2010-08-15 12 +459 val_459 2010-08-15 12 +460 val_460 2010-08-15 11 +460 val_460 2010-08-15 12 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 11 +462 val_462 2010-08-15 12 +462 val_462 2010-08-15 12 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 11 +463 val_463 2010-08-15 12 +463 val_463 2010-08-15 12 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 11 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +466 val_466 2010-08-15 12 +467 val_467 2010-08-15 11 +467 val_467 2010-08-15 12 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 11 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +468 val_468 2010-08-15 12 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 11 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +469 val_469 2010-08-15 12 +470 val_470 2010-08-15 11 +470 val_470 2010-08-15 12 +472 val_472 2010-08-15 11 +472 val_472 2010-08-15 12 +475 val_475 2010-08-15 11 +475 val_475 2010-08-15 12 +477 val_477 2010-08-15 11 +477 val_477 2010-08-15 12 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 11 +478 val_478 2010-08-15 12 +478 val_478 2010-08-15 12 +479 val_479 2010-08-15 11 +479 val_479 2010-08-15 12 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 11 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +480 val_480 2010-08-15 12 +481 val_481 2010-08-15 11 +481 val_481 2010-08-15 12 +482 val_482 2010-08-15 11 +482 val_482 2010-08-15 12 +483 val_483 2010-08-15 11 +483 val_483 2010-08-15 12 +484 val_484 2010-08-15 11 +484 val_484 2010-08-15 12 +485 val_485 2010-08-15 11 +485 val_485 2010-08-15 12 +487 val_487 2010-08-15 11 +487 val_487 2010-08-15 12 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 11 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +489 val_489 2010-08-15 12 +490 val_490 2010-08-15 11 +490 val_490 2010-08-15 12 +491 val_491 2010-08-15 11 +491 val_491 2010-08-15 12 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 11 +492 val_492 2010-08-15 12 +492 val_492 2010-08-15 12 +493 val_493 2010-08-15 11 +493 val_493 2010-08-15 12 +494 val_494 2010-08-15 11 +494 val_494 2010-08-15 12 +495 val_495 2010-08-15 11 +495 val_495 2010-08-15 12 +496 val_496 2010-08-15 11 +496 val_496 2010-08-15 12 +497 val_497 2010-08-15 11 +497 val_497 2010-08-15 12 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 11 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 +498 val_498 2010-08-15 12 diff --git a/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 b/sql/hive/src/test/resources/golden/merge4-9-ad3dc168c8b6f048717e39ab16b0a319 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a new file mode 100644 index 000000000000..390d344ecb9d --- /dev/null +++ b/sql/hive/src/test/resources/golden/no from clause-0-b42b408a87b258921240058f880a721a @@ -0,0 +1 @@ +1 1 -1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e b/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b b/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 000000000000..e74deff51c9b --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +NULL 1 +NULL NULL +1.0 NULL +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 b/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b new file mode 100644 index 000000000000..00ebb521970d --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b @@ -0,0 +1,10 @@ +1.0 1 +1.0 1 +1.0 1 +1.0 1 +1.0 1 +fooNull 1 +fooNull fooNull +1.0 fooNull +1.0 1 +1.0 1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 b/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 b/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d b/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 new file mode 100644 index 000000000000..b00bcb362453 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 @@ -0,0 +1,6 @@ +a string +b string +c string +d string + +Detailed Table Information Table(tableName:base_tab, dbName:default, owner:animal, createTime:1423973915, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null), FieldSchema(name:c, type:string, comment:null), FieldSchema(name:d, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/base_tab, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973915, COLUMN_STATS_ACCURATE=true, totalSize=130, numRows=0, rawDataSize=0}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd b/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 b/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 b/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 new file mode 100644 index 000000000000..264c973ff7af --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 @@ -0,0 +1,4 @@ +a string +b string + +Detailed Table Information Table(tableName:null_tab3, dbName:default, owner:animal, createTime:1423973928, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.null.format=fooNull, serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973928, COLUMN_STATS_ACCURATE=true, totalSize=80, numRows=10, rawDataSize=70}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE) diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 new file mode 100644 index 000000000000..881917bcf1c6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 @@ -0,0 +1,18 @@ +CREATE TABLE `null_tab3`( + `a` string, + `b` string) +ROW FORMAT DELIMITED + NULL DEFINED AS 'fooNull' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3' +TBLPROPERTIES ( + 'numFiles'='1', + 'transient_lastDdlTime'='1423973928', + 'COLUMN_STATS_ACCURATE'='true', + 'totalSize'='80', + 'numRows'='10', + 'rawDataSize'='70') diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 new file mode 100644 index 000000000000..3a2e3f4984a0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 @@ -0,0 +1 @@ +-1 diff --git a/sql/hive/src/test/resources/golden/schema-less transform-0-d5738de14dd6e29da712ec3318f4118f b/sql/hive/src/test/resources/golden/schema-less transform-0-d5738de14dd6e29da712ec3318f4118f new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/schema-less transform-0-d5738de14dd6e29da712ec3318f4118f @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/schema-less transform-1-49624ef4e2c3cc2040c06660b926219b b/sql/hive/src/test/resources/golden/schema-less transform-1-49624ef4e2c3cc2040c06660b926219b new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/schema-less transform-1-49624ef4e2c3cc2040c06660b926219b @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 b/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 new file mode 100644 index 000000000000..1b79f38e25b2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/semicolon-0-f104632770dc96b81f00ccdac51fe5a8 @@ -0,0 +1 @@ +500 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/transform with SerDe-0-cdc393f3914c879787efe523f692b1e0 b/sql/hive/src/test/resources/golden/transform with SerDe-0-cdc393f3914c879787efe523f692b1e0 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with SerDe-0-cdc393f3914c879787efe523f692b1e0 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/transform with SerDe3-0-58a8b7eb07a949bc44dccb723222957f b/sql/hive/src/test/resources/golden/transform with SerDe3-0-58a8b7eb07a949bc44dccb723222957f new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with SerDe3-0-58a8b7eb07a949bc44dccb723222957f @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/transform with SerDe4-0-ba9ad2499a7408cb350c7abafaf9ea97 b/sql/hive/src/test/resources/golden/transform with SerDe4-0-ba9ad2499a7408cb350c7abafaf9ea97 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with SerDe4-0-ba9ad2499a7408cb350c7abafaf9ea97 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-703cca3c02ced422feb11dc13b744484 b/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-703cca3c02ced422feb11dc13b744484 new file mode 100644 index 000000000000..e34118512c1d --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-703cca3c02ced422feb11dc13b744484 @@ -0,0 +1,500 @@ +238 +86 +311 +27 +165 +409 +255 +278 +98 +484 +265 +193 +401 +150 +273 +224 +369 +66 +128 +213 +146 +406 +429 +374 +152 +469 +145 +495 +37 +327 +281 +277 +209 +15 +82 +403 +166 +417 +430 +252 +292 +219 +287 +153 +193 +338 +446 +459 +394 +237 +482 +174 +413 +494 +207 +199 +466 +208 +174 +399 +396 +247 +417 +489 +162 +377 +397 +309 +365 +266 +439 +342 +367 +325 +167 +195 +475 +17 +113 +155 +203 +339 +0 +455 +128 +311 +316 +57 +302 +205 +149 +438 +345 +129 +170 +20 +489 +157 +378 +221 +92 +111 +47 +72 +4 +280 +35 +427 +277 +208 +356 +399 +169 +382 +498 +125 +386 +437 +469 +192 +286 +187 +176 +54 +459 +51 +138 +103 +239 +213 +216 +430 +278 +176 +289 +221 +65 +318 +332 +311 +275 +137 +241 +83 +333 +180 +284 +12 +230 +181 +67 +260 +404 +384 +489 +353 +373 +272 +138 +217 +84 +348 +466 +58 +8 +411 +230 +208 +348 +24 +463 +431 +179 +172 +42 +129 +158 +119 +496 +0 +322 +197 +468 +393 +454 +100 +298 +199 +191 +418 +96 +26 +165 +327 +230 +205 +120 +131 +51 +404 +43 +436 +156 +469 +468 +308 +95 +196 +288 +481 +457 +98 +282 +197 +187 +318 +318 +409 +470 +137 +369 +316 +169 +413 +85 +77 +0 +490 +87 +364 +179 +118 +134 +395 +282 +138 +238 +419 +15 +118 +72 +90 +307 +19 +435 +10 +277 +273 +306 +224 +309 +389 +327 +242 +369 +392 +272 +331 +401 +242 +452 +177 +226 +5 +497 +402 +396 +317 +395 +58 +35 +336 +95 +11 +168 +34 +229 +233 +143 +472 +322 +498 +160 +195 +42 +321 +430 +119 +489 +458 +78 +76 +41 +223 +492 +149 +449 +218 +228 +138 +453 +30 +209 +64 +468 +76 +74 +342 +69 +230 +33 +368 +103 +296 +113 +216 +367 +344 +167 +274 +219 +239 +485 +116 +223 +256 +263 +70 +487 +480 +401 +288 +191 +5 +244 +438 +128 +467 +432 +202 +316 +229 +469 +463 +280 +2 +35 +283 +331 +235 +80 +44 +193 +321 +335 +104 +466 +366 +175 +403 +483 +53 +105 +257 +406 +409 +190 +406 +401 +114 +258 +90 +203 +262 +348 +424 +12 +396 +201 +217 +164 +431 +454 +478 +298 +125 +431 +164 +424 +187 +382 +5 +70 +397 +480 +291 +24 +351 +255 +104 +70 +163 +438 +119 +414 +200 +491 +237 +439 +360 +248 +479 +305 +417 +199 +444 +120 +429 +169 +443 +323 +325 +277 +230 +478 +178 +468 +310 +317 +333 +493 +460 +207 +249 +265 +480 +83 +136 +353 +172 +214 +462 +233 +406 +133 +175 +189 +454 +375 +401 +421 +407 +384 +256 +26 +134 +67 +384 +379 +18 +462 +492 +100 +298 +9 +341 +498 +146 +458 +362 +186 +285 +348 +167 +18 +273 +183 +281 +344 +97 +469 +315 +84 +28 +37 +448 +152 +348 +307 +194 +414 +477 +222 +126 +90 +169 +403 +400 +200 +97 diff --git a/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-82639dda9ba42df817466dffe2929174 b/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-82639dda9ba42df817466dffe2929174 new file mode 100644 index 000000000000..e34118512c1d --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with custom field delimiter-0-82639dda9ba42df817466dffe2929174 @@ -0,0 +1,500 @@ +238 +86 +311 +27 +165 +409 +255 +278 +98 +484 +265 +193 +401 +150 +273 +224 +369 +66 +128 +213 +146 +406 +429 +374 +152 +469 +145 +495 +37 +327 +281 +277 +209 +15 +82 +403 +166 +417 +430 +252 +292 +219 +287 +153 +193 +338 +446 +459 +394 +237 +482 +174 +413 +494 +207 +199 +466 +208 +174 +399 +396 +247 +417 +489 +162 +377 +397 +309 +365 +266 +439 +342 +367 +325 +167 +195 +475 +17 +113 +155 +203 +339 +0 +455 +128 +311 +316 +57 +302 +205 +149 +438 +345 +129 +170 +20 +489 +157 +378 +221 +92 +111 +47 +72 +4 +280 +35 +427 +277 +208 +356 +399 +169 +382 +498 +125 +386 +437 +469 +192 +286 +187 +176 +54 +459 +51 +138 +103 +239 +213 +216 +430 +278 +176 +289 +221 +65 +318 +332 +311 +275 +137 +241 +83 +333 +180 +284 +12 +230 +181 +67 +260 +404 +384 +489 +353 +373 +272 +138 +217 +84 +348 +466 +58 +8 +411 +230 +208 +348 +24 +463 +431 +179 +172 +42 +129 +158 +119 +496 +0 +322 +197 +468 +393 +454 +100 +298 +199 +191 +418 +96 +26 +165 +327 +230 +205 +120 +131 +51 +404 +43 +436 +156 +469 +468 +308 +95 +196 +288 +481 +457 +98 +282 +197 +187 +318 +318 +409 +470 +137 +369 +316 +169 +413 +85 +77 +0 +490 +87 +364 +179 +118 +134 +395 +282 +138 +238 +419 +15 +118 +72 +90 +307 +19 +435 +10 +277 +273 +306 +224 +309 +389 +327 +242 +369 +392 +272 +331 +401 +242 +452 +177 +226 +5 +497 +402 +396 +317 +395 +58 +35 +336 +95 +11 +168 +34 +229 +233 +143 +472 +322 +498 +160 +195 +42 +321 +430 +119 +489 +458 +78 +76 +41 +223 +492 +149 +449 +218 +228 +138 +453 +30 +209 +64 +468 +76 +74 +342 +69 +230 +33 +368 +103 +296 +113 +216 +367 +344 +167 +274 +219 +239 +485 +116 +223 +256 +263 +70 +487 +480 +401 +288 +191 +5 +244 +438 +128 +467 +432 +202 +316 +229 +469 +463 +280 +2 +35 +283 +331 +235 +80 +44 +193 +321 +335 +104 +466 +366 +175 +403 +483 +53 +105 +257 +406 +409 +190 +406 +401 +114 +258 +90 +203 +262 +348 +424 +12 +396 +201 +217 +164 +431 +454 +478 +298 +125 +431 +164 +424 +187 +382 +5 +70 +397 +480 +291 +24 +351 +255 +104 +70 +163 +438 +119 +414 +200 +491 +237 +439 +360 +248 +479 +305 +417 +199 +444 +120 +429 +169 +443 +323 +325 +277 +230 +478 +178 +468 +310 +317 +333 +493 +460 +207 +249 +265 +480 +83 +136 +353 +172 +214 +462 +233 +406 +133 +175 +189 +454 +375 +401 +421 +407 +384 +256 +26 +134 +67 +384 +379 +18 +462 +492 +100 +298 +9 +341 +498 +146 +458 +362 +186 +285 +348 +167 +18 +273 +183 +281 +344 +97 +469 +315 +84 +28 +37 +448 +152 +348 +307 +194 +414 +477 +222 +126 +90 +169 +403 +400 +200 +97 diff --git a/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8713b21483e1efb78ee90b61530479b b/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8713b21483e1efb78ee90b61530479b new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8713b21483e1efb78ee90b61530479b @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8d2b2e60551f69bfb44e555f5cff064 b/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8d2b2e60551f69bfb44e555f5cff064 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with custom field delimiter2-0-e8d2b2e60551f69bfb44e555f5cff064 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/transform with custom field delimiter3-0-d4f4f471819345e9ce1964e281ea5289 b/sql/hive/src/test/resources/golden/transform with custom field delimiter3-0-d4f4f471819345e9ce1964e281ea5289 new file mode 100644 index 000000000000..7aae61e5eb82 --- /dev/null +++ b/sql/hive/src/test/resources/golden/transform with custom field delimiter3-0-d4f4f471819345e9ce1964e281ea5289 @@ -0,0 +1,500 @@ +238 val_238 +86 val_86 +311 val_311 +27 val_27 +165 val_165 +409 val_409 +255 val_255 +278 val_278 +98 val_98 +484 val_484 +265 val_265 +193 val_193 +401 val_401 +150 val_150 +273 val_273 +224 val_224 +369 val_369 +66 val_66 +128 val_128 +213 val_213 +146 val_146 +406 val_406 +429 val_429 +374 val_374 +152 val_152 +469 val_469 +145 val_145 +495 val_495 +37 val_37 +327 val_327 +281 val_281 +277 val_277 +209 val_209 +15 val_15 +82 val_82 +403 val_403 +166 val_166 +417 val_417 +430 val_430 +252 val_252 +292 val_292 +219 val_219 +287 val_287 +153 val_153 +193 val_193 +338 val_338 +446 val_446 +459 val_459 +394 val_394 +237 val_237 +482 val_482 +174 val_174 +413 val_413 +494 val_494 +207 val_207 +199 val_199 +466 val_466 +208 val_208 +174 val_174 +399 val_399 +396 val_396 +247 val_247 +417 val_417 +489 val_489 +162 val_162 +377 val_377 +397 val_397 +309 val_309 +365 val_365 +266 val_266 +439 val_439 +342 val_342 +367 val_367 +325 val_325 +167 val_167 +195 val_195 +475 val_475 +17 val_17 +113 val_113 +155 val_155 +203 val_203 +339 val_339 +0 val_0 +455 val_455 +128 val_128 +311 val_311 +316 val_316 +57 val_57 +302 val_302 +205 val_205 +149 val_149 +438 val_438 +345 val_345 +129 val_129 +170 val_170 +20 val_20 +489 val_489 +157 val_157 +378 val_378 +221 val_221 +92 val_92 +111 val_111 +47 val_47 +72 val_72 +4 val_4 +280 val_280 +35 val_35 +427 val_427 +277 val_277 +208 val_208 +356 val_356 +399 val_399 +169 val_169 +382 val_382 +498 val_498 +125 val_125 +386 val_386 +437 val_437 +469 val_469 +192 val_192 +286 val_286 +187 val_187 +176 val_176 +54 val_54 +459 val_459 +51 val_51 +138 val_138 +103 val_103 +239 val_239 +213 val_213 +216 val_216 +430 val_430 +278 val_278 +176 val_176 +289 val_289 +221 val_221 +65 val_65 +318 val_318 +332 val_332 +311 val_311 +275 val_275 +137 val_137 +241 val_241 +83 val_83 +333 val_333 +180 val_180 +284 val_284 +12 val_12 +230 val_230 +181 val_181 +67 val_67 +260 val_260 +404 val_404 +384 val_384 +489 val_489 +353 val_353 +373 val_373 +272 val_272 +138 val_138 +217 val_217 +84 val_84 +348 val_348 +466 val_466 +58 val_58 +8 val_8 +411 val_411 +230 val_230 +208 val_208 +348 val_348 +24 val_24 +463 val_463 +431 val_431 +179 val_179 +172 val_172 +42 val_42 +129 val_129 +158 val_158 +119 val_119 +496 val_496 +0 val_0 +322 val_322 +197 val_197 +468 val_468 +393 val_393 +454 val_454 +100 val_100 +298 val_298 +199 val_199 +191 val_191 +418 val_418 +96 val_96 +26 val_26 +165 val_165 +327 val_327 +230 val_230 +205 val_205 +120 val_120 +131 val_131 +51 val_51 +404 val_404 +43 val_43 +436 val_436 +156 val_156 +469 val_469 +468 val_468 +308 val_308 +95 val_95 +196 val_196 +288 val_288 +481 val_481 +457 val_457 +98 val_98 +282 val_282 +197 val_197 +187 val_187 +318 val_318 +318 val_318 +409 val_409 +470 val_470 +137 val_137 +369 val_369 +316 val_316 +169 val_169 +413 val_413 +85 val_85 +77 val_77 +0 val_0 +490 val_490 +87 val_87 +364 val_364 +179 val_179 +118 val_118 +134 val_134 +395 val_395 +282 val_282 +138 val_138 +238 val_238 +419 val_419 +15 val_15 +118 val_118 +72 val_72 +90 val_90 +307 val_307 +19 val_19 +435 val_435 +10 val_10 +277 val_277 +273 val_273 +306 val_306 +224 val_224 +309 val_309 +389 val_389 +327 val_327 +242 val_242 +369 val_369 +392 val_392 +272 val_272 +331 val_331 +401 val_401 +242 val_242 +452 val_452 +177 val_177 +226 val_226 +5 val_5 +497 val_497 +402 val_402 +396 val_396 +317 val_317 +395 val_395 +58 val_58 +35 val_35 +336 val_336 +95 val_95 +11 val_11 +168 val_168 +34 val_34 +229 val_229 +233 val_233 +143 val_143 +472 val_472 +322 val_322 +498 val_498 +160 val_160 +195 val_195 +42 val_42 +321 val_321 +430 val_430 +119 val_119 +489 val_489 +458 val_458 +78 val_78 +76 val_76 +41 val_41 +223 val_223 +492 val_492 +149 val_149 +449 val_449 +218 val_218 +228 val_228 +138 val_138 +453 val_453 +30 val_30 +209 val_209 +64 val_64 +468 val_468 +76 val_76 +74 val_74 +342 val_342 +69 val_69 +230 val_230 +33 val_33 +368 val_368 +103 val_103 +296 val_296 +113 val_113 +216 val_216 +367 val_367 +344 val_344 +167 val_167 +274 val_274 +219 val_219 +239 val_239 +485 val_485 +116 val_116 +223 val_223 +256 val_256 +263 val_263 +70 val_70 +487 val_487 +480 val_480 +401 val_401 +288 val_288 +191 val_191 +5 val_5 +244 val_244 +438 val_438 +128 val_128 +467 val_467 +432 val_432 +202 val_202 +316 val_316 +229 val_229 +469 val_469 +463 val_463 +280 val_280 +2 val_2 +35 val_35 +283 val_283 +331 val_331 +235 val_235 +80 val_80 +44 val_44 +193 val_193 +321 val_321 +335 val_335 +104 val_104 +466 val_466 +366 val_366 +175 val_175 +403 val_403 +483 val_483 +53 val_53 +105 val_105 +257 val_257 +406 val_406 +409 val_409 +190 val_190 +406 val_406 +401 val_401 +114 val_114 +258 val_258 +90 val_90 +203 val_203 +262 val_262 +348 val_348 +424 val_424 +12 val_12 +396 val_396 +201 val_201 +217 val_217 +164 val_164 +431 val_431 +454 val_454 +478 val_478 +298 val_298 +125 val_125 +431 val_431 +164 val_164 +424 val_424 +187 val_187 +382 val_382 +5 val_5 +70 val_70 +397 val_397 +480 val_480 +291 val_291 +24 val_24 +351 val_351 +255 val_255 +104 val_104 +70 val_70 +163 val_163 +438 val_438 +119 val_119 +414 val_414 +200 val_200 +491 val_491 +237 val_237 +439 val_439 +360 val_360 +248 val_248 +479 val_479 +305 val_305 +417 val_417 +199 val_199 +444 val_444 +120 val_120 +429 val_429 +169 val_169 +443 val_443 +323 val_323 +325 val_325 +277 val_277 +230 val_230 +478 val_478 +178 val_178 +468 val_468 +310 val_310 +317 val_317 +333 val_333 +493 val_493 +460 val_460 +207 val_207 +249 val_249 +265 val_265 +480 val_480 +83 val_83 +136 val_136 +353 val_353 +172 val_172 +214 val_214 +462 val_462 +233 val_233 +406 val_406 +133 val_133 +175 val_175 +189 val_189 +454 val_454 +375 val_375 +401 val_401 +421 val_421 +407 val_407 +384 val_384 +256 val_256 +26 val_26 +134 val_134 +67 val_67 +384 val_384 +379 val_379 +18 val_18 +462 val_462 +492 val_492 +100 val_100 +298 val_298 +9 val_9 +341 val_341 +498 val_498 +146 val_146 +458 val_458 +362 val_362 +186 val_186 +285 val_285 +348 val_348 +167 val_167 +18 val_18 +273 val_273 +183 val_183 +281 val_281 +344 val_344 +97 val_97 +469 val_469 +315 val_315 +84 val_84 +28 val_28 +37 val_37 +448 val_448 +152 val_152 +348 val_348 +307 val_307 +194 val_194 +414 val_414 +477 val_477 +222 val_222 +126 val_126 +90 val_90 +169 val_169 +403 val_403 +400 val_400 +200 val_200 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 new file mode 100644 index 000000000000..c6f275a0db13 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 @@ -0,0 +1 @@ +0.0 NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 000000000000..573541ac9702 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 new file mode 100644 index 000000000000..cd35e5b290db --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 @@ -0,0 +1 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 new file mode 100644 index 000000000000..48ef97292ab6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 @@ -0,0 +1,3 @@ +reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection +Use this UDF to call Java methods by matching the argument signature + diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca b/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 new file mode 100644 index 000000000000..176ea0358d7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 @@ -0,0 +1,5 @@ +238 -18 238 238 238 238.0 238.0 238 val_238 val_238_concat false true false false false val_238 -1 -1 VALUE_238 al_238 al_2 VAL_238 val_238 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +86 86 86 86 86 86.0 86.0 86 val_86 val_86_concat true true true true true val_86 -1 -1 VALUE_86 al_86 al_8 VAL_86 val_86 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +311 55 311 311 311 311.0 311.0 311 val_311 val_311_concat false true false false false val_311 5 6 VALUE_311 al_311 al_3 VAL_311 val_311 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +27 27 27 27 27 27.0 27.0 27 val_27 val_27_concat false true false false false val_27 -1 -1 VALUE_27 al_27 al_2 VAL_27 val_27 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 +165 -91 165 165 165 165.0 165.0 165 val_165 val_165_concat false true false false false val_165 4 4 VALUE_165 al_165 al_1 VAL_165 val_165 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000 diff --git a/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar new file mode 100644 index 000000000000..37af9aafad8a Binary files /dev/null and b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala deleted file mode 100644 index f320d732fb77..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util._ - - -/** - * *** DUPLICATED FROM sql/core. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class QueryTest extends PlanTest { - - /** - * Runs the plan and makes sure the answer contains all of the keywords, or the - * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed - * @param exists true for make sure the keywords are listed in the output, otherwise - * to make sure none of the keyword are not listed in the output - * @param keywords keyword in string array - */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { - val outputs = rdd.collect().map(_.mkString).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") - } - } - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. - */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { - val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case o => o - }) - } - if (!isSorted) converted.sortBy(_.toString) else converted - } - val sparkAnswer = try rdd.collect().toSeq catch { - case e: Exception => - fail( - s""" - |Exception thrown while executing query: - |${rdd.queryExecution} - |== Exception == - |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) - } - - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" - |Results do not match for query: - |${rdd.logicalPlan} - |== Analyzed Plan == - |${rdd.queryExecution.analyzed} - |== Physical Plan == - |${rdd.queryExecution.executedPlan} - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) - } - } - - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) - } - - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala deleted file mode 100644 index 081d94b6fc02..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.scalatest.FunSuite - -/** - * *** DUPLICATED FROM sql/catalyst/plans. *** - * - * It is hard to have maven allow one subproject depend on another subprojects test code. - * So, we duplicate this code here. - */ -class PlanTest extends FunSuite { - - /** - * Since attribute references are given globally unique ids during analysis, - * we must normalize them to check if two different queries are identical. - */ - protected def normalizeExprIds(plan: LogicalPlan) = { - val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) - val minId = if (list.isEmpty) 0 else list.min - plan transformAllExpressions { - case a: AttributeReference => - AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) - } - } - - /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f95a6b43af35..fc6c3c35037b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,28 +17,16 @@ package org.apache.spark.sql.hive +import java.io.File + import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId +import org.apache.spark.util.Utils class CachedTableSuite extends QueryTest { - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. - */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -64,6 +52,12 @@ class CachedTableSuite extends QueryTest { sql("SELECT * FROM src"), preCacheResults) + assertCached(sql("SELECT * FROM src s")) + + checkAnswer( + sql("SELECT * FROM src s"), + preCacheResults) + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } @@ -86,12 +80,12 @@ class CachedTableSuite extends QueryTest { } test("Drop cached table") { - sql("CREATE TABLE test(a INT)") - cacheTable("test") - sql("SELECT * FROM test").collect() - sql("DROP TABLE test") - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { - sql("SELECT * FROM test").collect() + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") + intercept[AnalysisException] { + sql("SELECT * FROM cachedTableTest").collect() } } @@ -164,4 +158,49 @@ class CachedTableSuite extends QueryTest { assertCached(table("udfTest")) uncacheTable("udfTest") } + + test("REFRESH TABLE also needs to recache the data (data source tables)") { + val tempPath: File = Utils.createTempDir() + tempPath.delete() + table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + sql("DROP TABLE IF EXISTS refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Cache the table. + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) + // Append new data. + table("src").save(tempPath.toString, "parquet", SaveMode.Append) + // We are still using the old data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Refresh the table. + sql("REFRESH TABLE refreshTable") + // We are using the new data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + + // Drop the table and create it again. + sql("DROP TABLE refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + // Refresh the table. REFRESH TABLE command should not make a uncached + // table cached. + sql("REFRESH TABLE refreshTable") + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + + sql("DROP TABLE refreshTable") + Utils.deleteRecursively(tempPath) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala new file mode 100644 index 000000000000..d960a30e0073 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{OutputStream, PrintStream} + +import scala.util.Try + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.{AnalysisException, QueryTest} + + +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { + + before { + Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + } + + positionTest("ambiguous attribute reference 1", + "SELECT a from dupAttributes", "a") + + positionTest("ambiguous attribute reference 2", + "SELECT a, b from dupAttributes", "a") + + positionTest("ambiguous attribute reference 3", + "SELECT b, a from dupAttributes", "a") + + positionTest("unresolved attribute 1", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 2", + "SELECT x FROM src", "x") + + positionTest("unresolved attribute 3", + "SELECT key, x FROM src", "x") + + positionTest("unresolved attribute 4", + """SELECT key, + |x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 5", + """SELECT key, + | x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 6", + """SELECT key, + | + | 1 + x FROM src + """.stripMargin, "x") + + positionTest("unresolved attribute 7", + """SELECT key, + | + | 1 + x + 1 FROM src + """.stripMargin, "x") + + positionTest("multi-char unresolved attribute", + """SELECT key, + | + | 1 + abcd + 1 FROM src + """.stripMargin, "abcd") + + positionTest("unresolved attribute group by", + """SELECT key FROM src GROUP BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute order by", + """SELECT key FROM src ORDER BY + |x + """.stripMargin, "x") + + positionTest("unresolved attribute where", + """SELECT key FROM src + |WHERE x = true + """.stripMargin, "x") + + positionTest("unresolved attribute backticks", + "SELECT `x` FROM src", "`x`") + + positionTest("parse error", + "SELECT WHERE", "WHERE") + + positionTest("bad relation", + "SELECT * FROM badTable", "badTable") + + ignore("other expressions") { + positionTest("bad addition", + "SELECT 1 + array(1)", "1 + array") + } + + /** Hive can be very noisy, messing up the output of our tests. */ + private def quietly[A](f: => A): A = { + val origErr = System.err + val origOut = System.out + try { + System.setErr(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + System.setOut(new PrintStream(new OutputStream { + def write(b: Int) = {} + })) + + f + } finally { + System.setErr(origErr) + System.setOut(origOut) + } + } + + /** + * Creates a test that checks to see if the error thrown when analyzing a given query includes + * the location of the given token in the query string. + * + * @param name the name of the test + * @param query the query to analyze + * @param token a unique token in the string that should be indicated by the exception + */ + def positionTest(name: String, query: String, token: String): Unit = { + def parseTree = + Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + + test(name) { + val error = intercept[AnalysisException] { + quietly(sql(query)) + } + + assert(!error.getMessage.contains("Seq(")) + assert(!error.getMessage.contains("List(")) + + val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect { + case (l, i) if l.contains(token) => (l, i + 1) + }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query")) + val actualLine = error.line.getOrElse { + fail( + s"line not returned for error '${error.getMessage}' on token $token\n$parseTree" + ) + } + assert(actualLine === expectedLineNum, "wrong line") + + val expectedStart = line.indexOf(token) + val actualStart = error.startPosition.getOrElse { + fail( + s"start not returned for error on token $token\n" + + HiveQl.dumpTree(HiveQl.getAst(query)) + ) + } + assert(expectedStart === actualStart, + s"""Incorrect start position. + |== QUERY == + |$query + | + |== AST == + |$parseTree + | + |Actual: $actualStart, Expected: $expectedStart + |$line + |${" " * actualStart}^ + |0123456789 123456789 1234567890 + | 2 3 + """.stripMargin) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2d3ff680125a..2a7374cc172b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.util -import java.sql.Date import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.ql.udf.UDAFPercentile @@ -76,13 +75,13 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(0.asInstanceOf[Float]) :: Literal(0.asInstanceOf[Double]) :: Literal("0") :: - Literal(new Date(2014, 9, 23)) :: + Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1,2.0d,3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -117,21 +116,20 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { - dt1.zip(dt2).map { - case (dd1, dd2) => - assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + dt1.zip(dt2).foreach { case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info } } def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { - row1.zip(row2).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2).foreach { case (r1, r2) => + checkValue(r1, r2) } } def checkValues(row1: Seq[Any], row2: Row): Unit = { - row1.zip(row2.toSeq).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2.toSeq).foreach { case (r1, r2) => + checkValue(r1, r2) } } @@ -142,8 +140,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { assert(r1.compare(r2) === 0) case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => - r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } - case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0) + r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) } case (r1, r2) => assert(r1 === r2) } } @@ -168,7 +165,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = + constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -204,7 +202,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, + unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } @@ -214,8 +213,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -224,7 +225,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index aad48ada5264..fa8e11ffec2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT @@ -36,4 +37,11 @@ class HiveMetastoreCatalogSuite extends FunSuite { assert(HiveMetastoreTypes.toMetastoreType(udt) === HiveMetastoreTypes.toMetastoreType(udt.sqlType)) } + + test("duplicated metastore relations") { + import TestHive.implicits._ + val df = TestHive.sql("SELECT * FROM src") + println(df.queryExecution) + df.as('a).join(df.as('b), $"a.key" === $"b.key") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala new file mode 100644 index 000000000000..7ff5719adb3a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.{QueryTest, SQLConf} + +case class Cases(lower: String, UPPER: String) + +class HiveParquetSuite extends QueryTest with ParquetTest { + val sqlContext = TestHive + + import sqlContext._ + + def run(prefix: String): Unit = { + test(s"$prefix: Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) + } + } + + test(s"$prefix: SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) + } + } + + test(s"$prefix: Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + checkAnswer( + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) + } + } + + test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) + } + } + } + + test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) + parquetFile(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + } + } + } + } + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 0e6636d38ed3..ecb990e8aac9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,23 +19,38 @@ package org.apache.spark.sql.hive import java.io.File -import com.google.common.io.Files +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) -class InsertIntoHiveTableSuite extends QueryTest { +case class ThreeCloumntable(key: Int, value: String, key1: String) + +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val testData = TestHive.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) - testData.registerTempTable("testData") + (1 to 100).map(i => TestData(i, i.toString))).toDF() + + before { + // Since every we are doing tests for DDL statements, + // it is better to reset before every test. + TestHive.reset() + // Register the testData, which will be used in every test. + testData.registerTempTable("testData") + } test("insertInto() HiveTable") { - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. testData.insertInto("createAndInsertTest") @@ -43,7 +58,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) // Add more data. @@ -52,7 +67,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq + testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq ) // Now overwrite. @@ -61,29 +76,31 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) } test("Double create fails when allowExisting = false") { - createTable[TestData]("doubleCreateAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[org.apache.hadoop.hive.ql.metadata.HiveException] { - createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false) - } + val message = intercept[QueryExecutionException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + }.getMessage + + println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { - createTable[TestData]("createAndInsertTest") - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = TestHive.createDataFrame(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -96,13 +113,38 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("SPARK-4203:random partition directory order") { - createTable[TestData]("tmp_table") - val tmpDir = Files.createTempDir() - sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + sql("CREATE TABLE tmp_table (key int, value string)") + val tmpDir = Utils.createTempDir() + sql( + s""" + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='1') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='2') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='3') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='4') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() val folders = dir.filter(_.isDirectory).toList @@ -127,8 +169,8 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithArrayValue") + val df = TestHive.createDataFrame(rowRDD, schema) + df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -144,8 +186,8 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = TestHive.createDataFrame(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -161,8 +203,8 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithStructValue") + val df = TestHive.createDataFrame(rowRDD, schema) + df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -172,4 +214,51 @@ class InsertIntoHiveTableSuite extends QueryTest { sql("DROP TABLE hiveTableWithStructValue") } + + test("SPARK-5498:partition schema does not match table schema") { + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() + + val tmpDir = Utils.createTempDir() + sql( + s""" + |CREATE TABLE table_with_partition(key int,value string) + |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (ds='1') SELECT key,value FROM testData + """.stripMargin) + + // test schema the same between partition and table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + // test difference type of field + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + // add column to table + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + testDatawithNull.collect().toSeq + ) + + // change column name to table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + testData.collect().toSeq + ) + + sql("DROP TABLE table_with_partition") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala new file mode 100644 index 000000000000..e12a6c21ccac --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -0,0 +1,81 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + override def beforeAll(): Unit = { + // The catalog in HiveContext is a case insensitive one. + catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + } + + override def afterAll(): Unit = { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } + + test("get all tables of current database") { + Seq(tables(), sql("SHOW TABLes")).foreach { + case allTables => + // We are using default DB. + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hivelisttablessuitetable'"), + Row("hivelisttablessuitetable", false)) + assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) + } + } + + test("getting all tables with a database name") { + Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach { + case allTables => + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + checkAnswer( + allTables.filter("tableName = 'indblisttablessuitetable'"), + Row("indblisttablessuitetable", true)) + assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hiveindblisttablessuitetable'"), + Row("hiveindblisttablessuitetable", false)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7408c7ffd69e..e09c702c8969 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -19,26 +19,37 @@ package org.apache.spark.sql.hive import java.io.File +import scala.collection.mutable.ArrayBuffer + import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ import org.apache.spark.util.Utils import org.apache.spark.sql.types._ - -/* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.LogicalRelation /** * Tests for persisting tables created though the data sources API into the metastore. */ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { + override def afterEach(): Unit = { reset() + Utils.deleteRecursively(tempPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + var tempPath: File = Utils.createTempDir() + tempPath.delete() test ("persistent JSON table") { sql( @@ -94,7 +105,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructField("", innerStruct, true) :: StructField("b", StringType, true) :: Nil) - assert(expectedSchema == table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) jsonFile(filePath).registerTempTable("expectedJsonTable") @@ -137,12 +148,18 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { intercept[Exception] { sql("SELECT * FROM jsonTable").collect() } + + assert( + (new File(filePath)).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json") + val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -158,7 +175,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) // Schema is cached so the new column does not show. The updated values in existing columns // will show. @@ -166,7 +184,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a1", "b1")) - refreshTable("jsonTable") + sql("REFRESH TABLE jsonTable") // Check that the refresh worked checkAnswer( @@ -176,9 +194,10 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json") + val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -194,7 +213,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql("DROP TABLE jsonTable") @@ -240,7 +260,139 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { invalidateTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema == table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + + test("CTAS") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${tempPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + + test("CTAS with IF NOT EXISTS") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${tempPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${tempPath}' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + assert(message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s""" + |CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${tempPath}' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + + test("CTAS a managed table") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s""" + |CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") + + sql( + s""" + |CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${expectedPath}' + |) + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("loadedTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable").collect() + ) + + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { @@ -255,4 +407,353 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("DROP TABLE jsonTable").collect().foreach(println) } + + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). + df.saveAsTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str${i}"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str${i}"))) + + invalidateTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str${i}"))) + + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str${i}"))) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("save table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). + df.saveAsTable("savedJsonTable") + + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.saveAsTable("savedJsonTable", SaveMode.Append) + } + + // We can overwrite it. + df.saveAsTable("savedJsonTable", SaveMode.Overwrite) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) + assert(df.schema === table("savedJsonTable").schema) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + + // Create an external table by specifying the path. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer( + jsonFile(tempPath.toString), + df.collect()) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("create external table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + df.collect()) + + var message = intercept[AnalysisException] { + createExternalTable("createdJsonTable", filePath.toString) + }.getMessage + assert(message.contains("Table createdJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer( + jsonFile(tempPath.toString), + df.collect()) + + // Try to specify the schema. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable").collect()) + + sql("DROP TABLE createdJsonTable") + + message = intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage + assert( + message.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + + sql("DROP TABLE savedJsonTable") + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + if (HiveShim.version == "0.13.1") { + test("scan a parquet table created through a CTAS statement") { + val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") + val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + setConf("spark.sql.hive.convertMetastoreParquet", "true") + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + jsonRDD(rdd).registerTempTable("jt") + sql( + """ + |create table test_parquet_ctas STORED AS parquET + |AS select tmp.a from jt tmp where tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil + ) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation2].getCanonicalName}") + } + + // Clenup and reset confs. + sql("DROP TABLE IF EXISTS jt") + sql("DROP TABLE IF EXISTS test_parquet_ctas") + setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) + } + } + + test("Pre insert nullability check (ArrayType)") { + val df1 = + createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") + val expectedSchema1 = + StructType( + StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) + assert(df1.schema === expectedSchema1) + df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite) + + val df2 = + createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema2 = + StructType( + StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) + assert(df2.schema === expectedSchema2) + df2.insertInto("arrayInParquet", overwrite = false) + createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto. + createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a") + .saveAsTable("arrayInParquet", "parquet", SaveMode.Append) + refreshTable("arrayInParquet") + + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) + + sql("DROP TABLE arrayInParquet") + } + + test("Pre insert nullability check (MapType)") { + val df1 = + createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val expectedSchema1 = + StructType( + StructField("a", mapType1, nullable = true) :: Nil) + assert(df1.schema === expectedSchema1) + df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite) + + val df2 = + createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val expectedSchema2 = + StructType( + StructField("a", mapType2, nullable = true) :: Nil) + assert(df2.schema === expectedSchema2) + df2.insertInto("mapInParquet", overwrite = false) + createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto. + createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .saveAsTable("mapInParquet", "parquet", SaveMode.Append) + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + + sql("DROP TABLE mapInParquet") + } + + test("SPARK-6024 wide schema support") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) + assert( + schema.json.size > conf.schemaStringLengthThreshold, + "To correctly test the fix of SPARK-6024, the value of " + + s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. + val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + + test("insert into a table") { + def createDF(from: Int, to: Int): DataFrame = + createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + + createDF(0, 9).saveAsTable("insertParquet", "parquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(10, 19).saveAsTable("insertParquet", "parquet") + } + + createDF(10, 19).saveAsTable("insertParquet", "parquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + + createDF(20, 29).saveAsTable("insertParquet", "parquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) + + intercept[AnalysisException] { + createDF(30, 39).saveAsTable("insertParquet") + } + + createDF(30, 39).saveAsTable("insertParquet", SaveMode.Append) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) + + createDF(40, 49).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).saveAsTable("insertParquet", SaveMode.Overwrite) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).saveAsTable("insertParquet", SaveMode.Ignore) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).insertInto("insertParquet", overwrite = true) + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala new file mode 100644 index 000000000000..a787fa5546e7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import com.google.common.io.Files + +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils + + +class QueryPartitionSuite extends QueryTest { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + test("SPARK-5068: query data when path doesn't exists"){ + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + + // delete the path of one partition + val folders = tmpDir.listFiles.filter(_.isDirectory) + Utils.deleteRecursively(folders(0)) + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect) + + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala new file mode 100644 index 000000000000..d6ddd539d159 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.hive.test.TestHive + +class SerializationSuite extends FunSuite { + + test("[SPARK-5840] HiveContext should be serializable") { + val hiveContext = new HiveContext(TestHive.sparkContext) + hiveContext.hiveconf + new JavaSerializer(new SparkConf()).newInstance().serialize(hiveContext) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 6f07fd5a879c..00a69de9e426 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,18 +120,18 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { - val rdd = sql("""SELECT * FROM src""") - val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => + val df = sql("""SELECT * FROM src""") + val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") + assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -142,13 +142,13 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { after: () => Unit, query: String, expectedAnswer: Seq[Row], - ct: ClassTag[_]) = { + ct: ClassTag[_]): Unit = { before() - var rdd = sql(query) + var df = sql(query) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold @@ -157,21 +157,21 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, expectedAnswer) // check correctness of output + checkAnswer(df, expectedAnswer) // check correctness of output TestHive.conf.settings.synchronized { val tmp = conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") - rdd = sql(query) - bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } + df = sql(query) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") @@ -199,10 +199,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin val answer = Row(86, "val_86") - var rdd = sql(leftSemiJoinQuery) + var df = sql(leftSemiJoinQuery) // Assert src has a size smaller than the threshold. - val sizes = rdd.queryExecution.analyzed.collect { + val sizes = df.queryExecution.analyzed.collect { case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes @@ -213,25 +213,25 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. - var bhj = rdd.queryExecution.sparkPlan.collect { + var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.size === 1, - s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + s"actual query plans do not contain broadcast join: ${df.queryExecution}") - checkAnswer(rdd, answer) // check correctness of output + checkAnswer(df, answer) // check correctness of output TestHive.conf.settings.synchronized { val tmp = conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") - rdd = sql(leftSemiJoinQuery) - bhj = rdd.queryExecution.sparkPlan.collect { + df = sql(leftSemiJoinQuery) + bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = rdd.queryExecution.sparkPlan.collect { + val shj = df.queryExecution.sparkPlan.collect { case j: LeftSemiJoinHash => j } assert(shj.size === 1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala new file mode 100644 index 000000000000..85b6bc93d712 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +/* Implicits */ + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHive._ + +case class FunctionResult(f1: String, f2: String) + +class UDFSuite extends QueryTest { + test("UDF case insensitive") { + udf.register("random0", () => { Math.random()}) + udf.register("RANDOM1", () => { Math.random()}) + udf.register("strlenScala", (_: String).length + (_:Int)) + assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 42a82c1fbf5c..a3f5921a0cb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class BigDataBenchmarkSuite extends HiveComparisonTest { val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( "rankings", @@ -63,7 +64,7 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | searchWord STRING, | duration INT) | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," - | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), TestTable( "documents", @@ -83,7 +84,10 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") createQueryTest("query2", - "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + """ + |SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits + |GROUP BY SUBSTR(sourceIP, 1, 10) + """.stripMargin) createQueryTest("query3", """ @@ -113,8 +117,8 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { |CREATE TABLE url_counts_total AS | SELECT SUM(count) AS totalCount, destpage | FROM url_counts_partial GROUP BY destpage - |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic - |-- given different input splits. + |-- The following queries run, but generate different results in HIVE + |-- likely because the UDF is not deterministic given different input splits. |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial |-- SELECT COUNT(*) FROM url_counts_partial |-- SELECT * FROM url_counts_partial diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index f8a957d55d57..027056d4b865 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,8 +22,8 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.Logging +import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} -import org.apache.spark.sql.hive.DescribeCommand import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -138,7 +138,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -241,7 +241,10 @@ abstract class HiveComparisonTest // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val allQueries = sql.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq + val sqlWithoutComment = + sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") + val allQueries = + sqlWithoutComment.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq // TODO: DOCUMENT UNSUPPORTED val queryList = @@ -252,8 +255,9 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") .filterNot(_ contains "hive.exec.post.hooks") - if (allQueries != queryList) + if (allQueries != queryList) { logWarning(s"Simplifications made on unsupported operations for test $testCaseName") + } lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -296,19 +300,22 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.logical) + hiveQueries.foreach(_.analyzed) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. - if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) { sys.error("hive exec hooks not supported for tests.") + } - logWarning(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i + 1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { - case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _: ExplainCommand => + // No need to execute EXPLAIN queries as we don't check the output. + Nil case _ => TestHive.runSqlHive(queryString) } @@ -391,21 +398,24 @@ abstract class HiveComparisonTest case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => if (System.getProperty("spark.hive.canarytest") != null) { - // When we encounter an error we check to see if the environment is still okay by running a simple query. - // If this fails then we halt testing since something must have gone seriously wrong. + // When we encounter an error we check to see if the environment is still + // okay by running a simple query. If this fails then we halt testing since + // something must have gone seriously wrong. try { new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") - // The testing setup traps exits so wait here for a long time so the developer can see when things started - // to go wrong. + logError(s"FATAL ERROR: Canary query threw $e This implies that the " + + "testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer + // can see when things started to go wrong. Thread.sleep(1000000) } } - // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + // If the canary query didn't fail then the environment is still okay, + // so just throw the original exception. throw originalException } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala new file mode 100644 index 000000000000..efbef68cd444 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A set of tests that validates commands can also be queried by like a table + */ +class HiveOperatorQueryableSuite extends QueryTest { + test("SPARK-5324 query result of describe command") { + loadTestTable("src") + + // register a describe command to be a temp table + sql("desc src").registerTempTable("mydesc") + checkAnswer( + sql("desc mydesc"), + Seq( + Row("col_name", "string", "name of the column"), + Row("data_type", "string", "data type of the column"), + Row("comment", "string", "comment of the column"))) + + checkAnswer( + sql("select * from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + + checkAnswer( + sql("select col_name, data_type, comment from mydesc"), + Seq( + Row("key", "int", null), + Row("value", "string", null))) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index c939e6e99d28..bdb53ddf59c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -22,10 +22,12 @@ import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { import TestHive._ + import TestHive.implicits._ test("udf constant folding") { - val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 02518d516261..f7b37dae0a5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. - * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * that should be included. Additionally, there is support for whitelisting and blacklisting + * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ @@ -54,15 +55,17 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { + } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. - if(System.getProperty(whiteListProperty) == null && !runAll) + if (System.getProperty(whiteListProperty) == null && !runAll) { ignore(testCaseName) {} + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index df72be7746ac..ac10b173307d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,21 +27,25 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault + import org.apache.spark.sql.hive.test.TestHive.implicits._ + override def beforeAll() { TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) @@ -56,13 +60,47 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { Locale.setDefault(originalLocale) } - test("SPARK-4908: concurent hive native commands") { + test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") - sql("SHOW TABLES") + sql("SHOW DATABASES") + } + } + + createQueryTest("insert table with generator with column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + createQueryTest("insert table with generator with multiple column names", + """ + | CREATE TABLE gen_tmp (key Int, value String); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3; + | SELECT key, value FROM gen_tmp ORDER BY key, value ASC; + """.stripMargin) + + createQueryTest("insert table with generator without column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + test("multiple generator in projection") { + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)), key FROM src").collect() + } + + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect() } } - + createQueryTest("! operator", """ |SELECT a FROM ( @@ -199,6 +237,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") + createQueryTest("no from clause", + "SELECT 1, +1, -1") + createQueryTest("boolean = number", """ |SELECT @@ -231,7 +272,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } createQueryTest("modulus", - "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + + "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") @@ -252,8 +294,30 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Cast Timestamp to Timestamp in UDF", """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) - | FROM src LIMIT 1 + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 1", + """ + | SELECT + | CAST(CAST('1970-01-01 22:00:00' AS timestamp) AS date) == + | CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) + | FROM src LIMIT 1 + """.stripMargin) + + createQueryTest("Date comparison test 2", + "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") + + createQueryTest("Date cast", + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 """.stripMargin) createQueryTest("Simple Average", @@ -281,7 +345,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT * FROM src a JOIN src b ON a.key = b.key") createQueryTest("small.cartesian", - "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN " + + "(SELECT key FROM src WHERE key = 2) b") createQueryTest("length.udf", "SELECT length(\"test\") FROM src LIMIT 1") @@ -306,6 +371,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |DROP DATABASE IF EXISTS testdb CASCADE """.stripMargin) + createQueryTest("create table as with db name within backticks", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE `testdb`.`createdtable` AS SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + createQueryTest("insert table with db name", """ |CREATE DATABASE IF NOT EXISTS testdb; @@ -328,6 +401,80 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("transform", "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") + createQueryTest("schema-less transform", + """ + |SELECT TRANSFORM (key, value) USING 'cat' FROM src; + |SELECT TRANSFORM (*) USING 'cat' FROM src; + """.stripMargin) + + val delimiter = "'\t'" + + createQueryTest("transform with custom field delimiter", + s""" + |SELECT TRANSFORM (key) ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} + |USING 'cat' AS (tKey) ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} FROM src; + """.stripMargin.replaceAll("\n", " ")) + + createQueryTest("transform with custom field delimiter2", + s""" + |SELECT TRANSFORM (key, value) ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} + |USING 'cat' ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} FROM src; + """.stripMargin.replaceAll("\n", " ")) + + createQueryTest("transform with custom field delimiter3", + s""" + |SELECT TRANSFORM (*) ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} + |USING 'cat' ROW FORMAT DELIMITED FIELDS TERMINATED BY ${delimiter} FROM src; + """.stripMargin.replaceAll("\n", " ")) + + createQueryTest("transform with SerDe", + """ + |SELECT TRANSFORM (key, value) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; + """.stripMargin.replaceAll("\n", " ")) + + test("transform with SerDe2") { + + sql("CREATE TABLE small_src(key INT, value STRING)") + sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10") + + val expected = sql("SELECT key FROM small_src").collect().head + val res = sql( + """ + |SELECT TRANSFORM (key) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |WITH SERDEPROPERTIES ('avro.schema.literal'='{"namespace": + |"testing.hive.avro.serde","name": "src","type": "record","fields": + |[{"name":"key","type":"int"}]}') USING 'cat' AS (tKey INT) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.avro.AvroSerDe' WITH SERDEPROPERTIES + |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": + |"src","type": "record","fields": [{"name":"key","type":"int"}]}') + |FROM small_src + """.stripMargin.replaceAll("\n", " ")).collect().head + + assert(expected(0) === res(0)) + } + + createQueryTest("transform with SerDe3", + """ + |SELECT TRANSFORM (*) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES + |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; + """.stripMargin.replaceAll("\n", " ")) + + createQueryTest("transform with SerDe4", + """ + |SELECT TRANSFORM (*) ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES + |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE + |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES + |('serialization.last.column.takes.rest'='true') FROM src; + """.stripMargin.replaceAll("\n", " ")) + createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") @@ -343,10 +490,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view2", "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") - createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + // scalastyle:off createQueryTest("lateral view4", """ |create table src_lv1 (key string, value string); @@ -356,6 +503,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX """.stripMargin) + // scalastyle:on createQueryTest("lateral view5", "FROM src SELECT explode(array(key+3, key+4))") @@ -363,11 +511,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view6", "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + createQueryTest("Specify the udtf output", + "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") } - test("SchemaRDD toString") { + test("DataFrame toString") { sql("SHOW TABLES").toString sql("SELECT * FROM src").toString } @@ -426,6 +578,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("select null from table", "SELECT null FROM src LIMIT 1") + createQueryTest("CTE feature #1", + "with q1 as (select key from src) select * from q1 where key = 5") + + createQueryTest("CTE feature #2", + """with q1 as (select * from src where key= 5), + |q2 as (select * from src s2 where key = 4) + |select value from q1 union all select value from q2 + """.stripMargin) + + createQueryTest("CTE feature #3", + """with q1 as (select key from src) + |from q1 + |select * where key = 4 + """.stripMargin) + test("predicates contains an empty AttributeSet() references") { sql( """ @@ -465,7 +632,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerTempTable("REGisteredTABle") + testData.toDF().registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -473,16 +640,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: SchemaRDD) = { + def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } - test("SPARK-1704: Explain commands as a SchemaRDD") { + test("SPARK-1704: Explain commands as a DataFrame") { sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - val rdd = sql("explain select key, count(value) from src group by key") - assert(isExplanation(rdd)) + val df = sql("explain select key, count(value) from src group by key") + assert(isExplanation(df)) TestHive.reset() } @@ -490,7 +657,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -508,25 +675,44 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5383 alias for udfs with multi output columns") { + assert( + sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") + .collect() + .size == 5) + + assert( + sql("select a, b from (select stack(2, key, value, key, value) as (a, b) from src) t limit 5") + .collect() + .size == 5) + } + + test("SPARK-5367: resolve star expression in udf") { + assert(sql("select concat(*) from src limit 5").collect().size == 5) + assert(sql("select array(*) from src limit 5").collect().size == 5) + assert(sql("select concat(key, *) from src limit 5").collect().size == 5) + assert(sql("select array(key, *) from src limit 5").collect().size == 5) + } + test("Query Hive native command execution result") { - val tableName = "test_native_commands" + val databaseName = "test_native_commands" assertResult(0) { - sql(s"DROP TABLE IF EXISTS $tableName").count() + sql(s"DROP DATABASE IF EXISTS $databaseName").count() } assertResult(0) { - sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() + sql(s"CREATE DATABASE $databaseName").count() } assert( - sql("SHOW TABLES") + sql("SHOW DATABASES") .select('result) .collect() .map(_.getString(0)) - .contains(tableName)) + .contains(databaseName)) - assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key"))) + assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key"))) TestHive.reset() } @@ -619,12 +805,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerTempTable("test_describe_commands2") + testData.toDF().registerTempTable("test_describe_commands2") assertResult( Array( - Row("a", "IntegerType", null), - Row("b", "StringType", null)) + Row("a", "int", ""), + Row("b", "string", "")) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) @@ -663,6 +849,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE alter1") } + test("ADD JAR command 2") { + // this is a test case from mapjoin_addjar.q + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + if (HiveShim.version == "0.13.1") { + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") @@ -738,6 +939,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") { + sql(""" + |create table sc as select * + |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows) + |union all + |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s + """.stripMargin) + sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile") + sql("set hive.exec.dynamic.partition=true") + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("insert overwrite table sc_part partition(ts) select * from sc") + sql("drop table sc_part") + } + test("Partition spec validation") { sql("DROP TABLE IF EXISTS dp_test") sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") @@ -763,8 +980,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") + sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") sql( """ @@ -842,8 +1059,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { + def collectResults(df: DataFrame): Set[(String, String)] = + df.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 422e843d2b0d..8ad362750422 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -17,27 +17,37 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { - case class NestedData(a: Seq[NestedData2], B: NestedData2) - case class NestedData2(a: NestedData3, B: NestedData3) - case class NestedData3(a: Int, B: Int) - test("SPARK-3698: case insensitive test for nested data") { - sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested") + jsonRDD(sparkContext.makeRDD( + """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } + test("SPARK-5278: check ambiguous reference to fields") { + jsonRDD(sparkContext.makeRDD( + """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") + + // there are 2 filed matching field name "b", we should report Ambiguous reference error + val exception = intercept[AnalysisException] { + sql("SELECT a[0].b from nested").queryExecution.analyzed + } + assert(exception.getMessage.contains("Ambiguous reference to fields")) + } + createQueryTest("table.attr", "SELECT src.key FROM src ORDER BY key LIMIT 1") @@ -67,8 +77,8 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), @@ -78,18 +88,27 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") + sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } + createQueryTest("test ambiguousReferences resolved as hive", + """ + |CREATE TABLE t1(x INT); + |CREATE TABLE t2(a STRUCT, k INT); + |INSERT OVERWRITE TABLE t1 SELECT 1 FROM src LIMIT 1; + |INSERT OVERWRITE TABLE t2 SELECT named_struct("x",1),1 FROM src LIMIT 1; + |SELECT a.x FROM t1 a JOIN t2 b ON a.x = b.k; + """.stripMargin) + /** * Negative examples. Currently only left here for documentation purposes. * TODO(marmbrus): Test that catalyst fails on these queries. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7486bfa82b00..5586a793618b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -25,17 +25,25 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validates support for Hive SerDe. */ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { - - override def beforeAll() = { + override def beforeAll(): Unit = { + import TestHive._ + import org.apache.hadoop.hive.serde2.RegexSerDe + super.beforeAll() TestHive.cacheTables = false + sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin) + sql(s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales") } + // table sales is not a cache table, and will be clear after reset + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales", false) + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") - createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") - createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 16f77a438e1a..ab53c6309e08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils @@ -82,10 +84,10 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("create table spark_4959 (col1 string)") sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( - 'col1.as('CaseSensitiveColName), - 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 48fffe53cf2f..f0f04f8c73fb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,8 +35,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { val nullVal = "null" baseTypes.init.foreach { i => - createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") - createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + createQueryTest(s"case when then $i else $nullVal end ", + s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", + s"SELECT case when true then $nullVal else $i end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { @@ -57,4 +59,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } + + test("COALESCE with different types") { + intercept[RuntimeException] { + TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f2374a215291..7f49eac49057 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -22,7 +22,7 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -47,7 +47,9 @@ case class ListStringCaseClass(l: Seq[String]) * A test suite for Hive custom UDFs. */ class HiveUdfSuite extends QueryTest { - import TestHive._ + + import TestHive.{udf, sql} + import TestHive.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -58,9 +60,9 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f3, | getStruct(1).f4, | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } - + test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { checkAnswer( sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), @@ -91,10 +93,19 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) - + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } @@ -102,19 +113,20 @@ class HiveUdfSuite extends QueryTest { test("Generic UDAF aggregates") { checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } - + test("UDFIntegerToString") { val testData = TestHive.sparkContext.parallelize( - IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") @@ -125,12 +137,12 @@ class HiveUdfSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() testData.registerTempTable("listListIntTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") @@ -140,12 +152,12 @@ class HiveUdfSuite extends QueryTest { test("UDFListString") { val testData = TestHive.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: - ListStringCaseClass(Seq("d", "e")) :: Nil) + ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") @@ -154,12 +166,12 @@ class HiveUdfSuite extends QueryTest { test("UDFStringString") { val testData = TestHive.sparkContext.parallelize( - StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") @@ -171,12 +183,12 @@ class HiveUdfSuite extends QueryTest { ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: - Nil) + Nil).toDF() testData.registerTempTable("TwoListTable") sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") @@ -236,7 +248,8 @@ class PairUdf extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8474d850c9c6..067b577f1560 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -143,7 +143,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { sql: String, expectedOutputColumns: Seq[String], expectedScannedColumns: Seq[String], - expectedPartValues: Seq[Seq[String]]) = { + expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7f9f1ac7cd80..4f8d0ac0e765 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,22 +17,144 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.hive.{MetastoreRelation, HiveShim} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) case class Nested3(f3: Int) +case class NestedArray2(b: Seq[Int]) +case class NestedArray1(a: NestedArray2) + +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("SPARK-6835: udtf in lateral view") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } + + test("SPARK-5371: union with null and sum") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + + val query = sql( + """ + |SELECT + | MIN(c1), + | MIN(c2) + |FROM ( + | SELECT + | SUM(c1) c1, + | NULL c2 + | FROM table1 + | UNION ALL + | SELECT + | NULL c1, + | SUM(c2) c2 + | FROM table1 + |) a + """.stripMargin) + checkAnswer(query, Row(1, 1) :: Nil) + } + + test("explode nested Field") { + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") + checkAnswer( + sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), @@ -40,6 +162,73 @@ class SQLQuerySuite extends QueryTest { ) } + test("CTAS without serde") { + def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + relation match { + case LogicalRelation(r: ParquetRelation2) => + if (!isDataSourceParquet) { + fail( + s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + s"${ParquetRelation2.getClass.getCanonicalName}.") + } + + case r: MetastoreRelation => + if (isDataSourceParquet) { + fail( + s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${classOf[MetastoreRelation].getCanonicalName}.") + } + } + } + + val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + + setConf("spark.sql.hive.convertCTAS", "true") + + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("Table ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql( + "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + setConf("spark.sql.hive.convertCTAS", originalConf) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( @@ -102,6 +291,37 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" ) + + if (HiveShim.version =="0.13.1") { + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + } + } } test("command substitution") { @@ -141,7 +361,8 @@ class SQLQuerySuite extends QueryTest { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + .toDF().registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) @@ -151,7 +372,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) - intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] { + intercept[AnalysisException] { sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() } } @@ -159,18 +380,18 @@ class SQLQuerySuite extends QueryTest { test("test CTAS") { checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) checkAnswer( - sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } test("SPARK-4825 save join to table") { - val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") testData.insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") testData.insertInto("test2") testData.insertInto("test2") - sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").saveAsTable("test") + sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq) @@ -222,7 +443,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT distinct key FROM src order by key").collect().toSeq) } - test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") { + test("SPARK-4963 DataFrame sample on mutable row return wrong result") { sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) .registerTempTable("sampled") @@ -244,7 +465,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - applySchema(rowRdd, schema).registerTempTable("testTable") + TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -267,4 +488,85 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE nullValuesInInnerComplexTypes") dropTempTable("testTable") } + + test("SPARK-4296 Grouping field with Hive UDF as sub expression") { + val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + Row("str-1", 1970)) + + dropTempTable("data") + + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + + dropTempTable("data") + } + + test("resolve udtf with single alias") { + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + jsonRDD(rdd).registerTempTable("data") + val df = sql("SELECT explode(a) AS val FROM data") + val col = df("val") + } + + test("logical.Project should not be resolved if it contains aggregates or generators") { + // This test is used to test the fix of SPARK-5875. + // The original issue was that Project's resolved will be true when it contains + // AggregateExpressions or Generators. However, in this case, the Project + // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of + // PreInsertionCasts will actually start to work before ImplicitGenerate and then + // generates an invalid query plan. + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) + jsonRDD(rdd).registerTempTable("data") + val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + setConf("spark.sql.hive.convertCTAS", "false") + + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + } + + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) + + sql("DROP TABLE explodeTest") + dropTempTable("data") + setConf("spark.sql.hive.convertCTAS", originalConf) + } + + test("sanity test for SPARK-6618") { + (1 to 100).par.map { i => + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + catalog.lookupRelation(Seq(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } + + test("test script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala new file mode 100644 index 000000000000..d5dd0bf58e70 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -0,0 +1,911 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation +import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} +import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) +// The data that also includes the partitioning key +case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) + +case class StructContainer(intStructField :Int, stringStructField: String) + +case class ParquetDataWithComplexTypes( + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +case class ParquetDataWithKeyAndComplexTypes( + p: Int, + intField: Int, + stringField: String, + structField: StructContainer, + arrayField: Seq[Int]) + +/** + * A suite to test the automatic conversion of metastore tables with parquet data to use the + * built in parquet support. + */ +class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table partitioned_parquet_with_key + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDirWithKey.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(normalTableDir, "normal").getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + """) + + sql(s""" + CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes + ( + intField INT, + stringField STRING, + structField STRUCT, + arrayField ARRAY + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + """) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)") + } + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") + } + + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + jsonRDD(rdd1).registerTempTable("jt") + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) + jsonRDD(rdd2).registerTempTable("jt_array") + + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + sql("DROP TABLE partitioned_parquet") + sql("DROP TABLE partitioned_parquet_with_key") + sql("DROP TABLE partitioned_parquet_with_complextypes") + sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") + sql("DROP TABLE normal_parquet") + sql("DROP TABLE IF EXISTS jt") + sql("DROP TABLE IF EXISTS jt_array") + setConf("spark.sql.hive.convertMetastoreParquet", "false") + } + + test(s"conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + case _: PhysicalRDD => true + }.nonEmpty) + } +} + +class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + sql("DROP TABLE IF EXISTS test_parquet") + + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("scan an empty parquet table") { + checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) + } + + test("scan an empty parquet table with upper case") { + checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0)) + } + + test("insert into an empty parquet table") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // Insert into am empty table. + sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"), + Row(6, "str6") :: Row(7, "str7") :: Nil + ) + // Insert overwrite. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + sql("DROP TABLE IF EXISTS test_insert_parquet") + + // Create it again. + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + // Insert overwrite an empty table. + sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), + Row(3, "str3") :: Row(4, "str4") :: Nil + ) + // Insert into the table. + sql("insert into table test_insert_parquet select a, b from jt") + checkAnswer( + sql(s"SELECT intField, stringField FROM test_insert_parquet"), + (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) + ) + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("scan a parquet table created through a CTAS statement") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) + + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail( + s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}") + } + + sql("DROP TABLE IF EXISTS test_parquet_ctas") + } + + test("MetastoreRelation in InsertIntoTable will be converted") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand( + InsertIntoDataSource( + LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("MetastoreRelation in InsertIntoHiveTable will be converted") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand( + InsertIntoDataSource( + LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } + + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + test("SPARK-6450 regression test") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation2) => r + }.size + } + + sql("DROP TABLE ms_convert") + } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } +} + +class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("MetastoreRelation in InsertIntoTable will not be converted") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case insert: execution.InsertIntoHiveTable => // OK + case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + + s"However, found ${o.toString}.") + } + + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } + + // TODO: enable it after the fix of SPARK-5950. + ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case insert: execution.InsertIntoHiveTable => // OK + case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + + s"However, found ${o.toString}.") + } + + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) + + sql("DROP TABLE IF EXISTS test_insert_parquet") + } +} + +/** + * A suite of tests for the Parquet support through the data sources API. + */ +class ParquetSourceSuiteBase extends ParquetPartitioningTest { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( s""" + create temporary table partitioned_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDir.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table partitioned_parquet_with_key + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKey.getCanonicalPath}' + ) + """) + + sql( s""" + create temporary table normal_parquet + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' + ) + """) + + sql( s""" + CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes + USING org.apache.spark.sql.parquet + OPTIONS ( + path '${partitionedTableDirWithComplexTypes.getCanonicalPath}' + ) + """) + } + + test("SPARK-6016 make sure to use the latest footers") { + sql("drop table if exists spark_6016_fix") + + // Create a DataFrame with two partitions. So, the created table will have two parquet files. + val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + // Create a DataFrame with four partitions. So, the created table will have four parquet files. + val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, + // since the new table has four parquet files, we are trying to read new footers from two files + // and then merge metadata in footers of these four (two outdated ones and two latest one), + // which will cause an error. + checkAnswer( + sql("select * from spark_6016_fix"), + (1 to 10).map(i => Row(i)) + ) + + sql("drop table spark_6016_fix") + } +} + +class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } + + test("values in arrays and maps stored in parquet are always nullable") { + val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") + val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) + val arrayType1 = ArrayType(IntegerType, containsNull = false) + val expectedSchema1 = + StructType( + StructField("m", mapType1, nullable = true) :: + StructField("a", arrayType1, nullable = true) :: Nil) + assert(df.schema === expectedSchema1) + + df.saveAsTable("alwaysNullable", "parquet") + + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = true) :: Nil) + + assert(table("alwaysNullable").schema === expectedSchema2) + + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + + sql("DROP TABLE alwaysNullable") + } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.saveAsParquetFile(filePath2) + val df4 = parquetFile(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } +} + +class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +/** + * A collection of tests for parquet data with various forms of partitioning. + */ +abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { + var partitionedTableDir: File = null + var normalTableDir: File = null + var partitionedTableDirWithKey: File = null + var partitionedTableDirWithComplexTypes: File = null + var partitionedTableDirWithKeyAndComplexTypes: File = null + + override def beforeAll(): Unit = { + partitionedTableDir = Utils.createTempDir() + normalTableDir = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .toDF() + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .toDF() + .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + + partitionedTableDirWithKey = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKey, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .toDF() + .saveAsParquetFile(partDir.getCanonicalPath) + } + + partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithKeyAndComplexTypes( + p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + } + + partitionedTableDirWithComplexTypes = Utils.createTempDir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") + sparkContext.makeRDD(1 to 10).map { i => + ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) + }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + } + } + + override protected def afterAll(): Unit = { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } + + Seq( + "partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes").foreach { table => + + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) + } + + test(s"pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + Row(10)) + } + + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } + + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + Seq( + "partitioned_parquet_with_key_and_complextypes", + "partitioned_parquet_with_complextypes").foreach { table => + + test(s"SPARK-5775 read struct from $table") { + checkAnswer( + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), + (1 to 10).map(i => Row(1, i, f"${i}_string"))) + } + + // Re-enable this after SPARK-5508 is fixed + ignore(s"SPARK-5775 read array from $table") { + checkAnswer( + sql(s"SELECT arrayField, p FROM $table WHERE p = 1"), + (1 to 10).map(i => Row(1 to i, 1))) + } + } + + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala deleted file mode 100644 index 581f66639949..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.hive.test.TestHive - -case class Cases(lower: String, UPPER: String) - -class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ - - test("Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) - } - } - - test("SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } - } - - test("Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) - } - } - - test("Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) - } - } - } - - - test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) - } - } - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala deleted file mode 100644 index 79fd99d9f89f..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ /dev/null @@ -1,271 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.File - -import org.apache.spark.sql.catalyst.expressions.Row -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive._ - -// The data where the partitioning key exists only in the directory structure. -case class ParquetData(intField: Int, stringField: String) -// The data that also includes the partitioning key -case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) - - -/** - * A suite to test the automatic conversion of metastore tables with parquet data to use the - * built in parquet support. - */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql(s""" - create external table partitioned_parquet - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDir.getCanonicalPath}' - """) - - sql(s""" - create external table partitioned_parquet_with_key - ( - intField INT, - stringField STRING - ) - PARTITIONED BY (p int) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${partitionedTableDirWithKey.getCanonicalPath}' - """) - - sql(s""" - create external table normal_parquet - ( - intField INT, - stringField STRING - ) - ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - STORED AS - INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - """) - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") - } - - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)") - } - - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - test("conversion is working") { - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: HiveTableScan => true - }.isEmpty) - assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true - }.nonEmpty) - } -} - -/** - * A suite of tests for the Parquet support through the data sources API. - */ -class ParquetSourceSuite extends ParquetPartitioningTest { - override def beforeAll(): Unit = { - super.beforeAll() - - sql( s""" - create temporary table partitioned_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDir.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table partitioned_parquet_with_key - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${partitionedTableDirWithKey.getCanonicalPath}' - ) - """) - - sql( s""" - create temporary table normal_parquet - USING org.apache.spark.sql.parquet - OPTIONS ( - path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' - ) - """) - } -} - -/** - * A collection of tests for parquet data with various forms of partitioning. - */ -abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { - var partitionedTableDir: File = null - var partitionedTableDirWithKey: File = null - - override def beforeAll(): Unit = { - partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") - partitionedTableDirWithKey.delete() - partitionedTableDirWithKey.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDirWithKey, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetDataWithKey(p, i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - } - - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"non-existant partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } - } - - test("non-part select(*)") { - checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), - Row(10)) - } -} diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index c0b7741bc3e5..33e96eaabfbf 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -34,16 +34,18 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} + +private[hive] case class HiveFunctionWrapper(functionClassName: String) + extends java.io.Serializable { -case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { // for Serialization def this() = this(null) @@ -133,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) @@ -160,7 +162,7 @@ private[hive] object HiveShim { if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { @@ -208,7 +210,7 @@ private[hive] object HiveShim { def getDataLocationPath(p: Partition) = p.getPartitionPath - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) + def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) def compatibilityBlackList = Seq( "decimal_.*", @@ -241,8 +243,23 @@ private[hive] object HiveShim { Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) } } + + def getConvertedOI( + inputOI: ObjectInspector, + outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) + } + + def prepareWritable(w: Writable): Writable = { + w + } + + def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} } -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) +private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) extends FileSinkDesc(dir, tableInfo, compressed) { } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index c04cda7bf153..d331c210e893 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,33 +17,35 @@ package org.apache.spark.sql.hive -import java.util.{ArrayList => JArrayList} -import java.util.Properties +import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} - +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** * This class provides the UDF creation and also the UDF instance serialization and @@ -53,22 +55,20 @@ import org.apache.spark.sql.types.{Decimal, DecimalType} * * @param functionClassName UDF class name */ -case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable { +private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], classOf[Class[_]]) method.setAccessible(true) @@ -81,7 +81,7 @@ case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Ex "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method @@ -218,7 +218,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) @@ -261,7 +261,7 @@ private[hive] object HiveShim { } def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]) + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) def getTimestampWritable(value: Any): hiveIo.TimestampWritable = if (value == null) { @@ -395,13 +395,39 @@ private[hive] object HiveShim { Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) } } + + def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { + if (crtTbl != null && crtTbl.getNullFormat() != null) { + tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) + } + } } /* - * Bug introdiced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. + * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. * Fix it through wrapper. */ -class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) +private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) extends Serializable with Logging { var compressCodec: String = _ var compressType: String = _ diff --git a/streaming/pom.xml b/streaming/pom.xml index 22b0d714b57f..5ca55a4f680b 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,8 +20,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -40,10 +40,34 @@ spark-core_${scala.binary.version} ${project.version} + + + + com.google.guava + guava + org.eclipse.jetty jetty-server + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-http + + + org.eclipse.jetty + jetty-servlet + + + org.scala-lang scala-library @@ -58,6 +82,11 @@ junit test + + org.seleniumhq.selenium + selenium-java + test + com.novocode junit-interface @@ -68,32 +97,12 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - test-jar-on-test-compile - test-compile - - test-jar - - - + maven-shade-plugin + + true + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b780282bdac3..7bfae253c3a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -43,10 +43,13 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll - def sparkConf = { - new SparkConf(false).setAll(sparkConfPairs) + def createSparkConf(): SparkConf = { + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") + val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") + newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + newSparkConf } def validate() { @@ -64,17 +67,18 @@ object Checkpoint extends Logging { val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r /** Get the checkpoint file for the given checkpoint time */ - def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) } /** Get the checkpoint backup file for the given checkpoint time */ - def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -82,6 +86,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -116,7 +121,10 @@ class CheckpointWriter( private var stopped = false private var fs_ : FileSystem = _ - class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + class CheckpointWriteHandler( + checkpointTime: Time, + bytes: Array[Byte], + clearCheckpointDataLater: Boolean) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() @@ -133,8 +141,11 @@ class CheckpointWriter( // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) - fos.write(bytes) - fos.close() + Utils.tryWithSafeFinally { + fos.write(bytes) + } { + fos.close() + } // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail @@ -151,8 +162,8 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) - if (allCheckpointFiles.size > 4) { + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) + if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) fs.delete(file, true) @@ -163,7 +174,7 @@ class CheckpointWriter( val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") - jobGenerator.onCheckpointCompletion(checkpointTime) + jobGenerator.onCheckpointCompletion(checkpointTime, clearCheckpointDataLater) return } catch { case ioe: IOException => @@ -177,15 +188,18 @@ class CheckpointWriter( } } - def write(checkpoint: Checkpoint) { + def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) { val bos = new ByteArrayOutputStream() val zos = compressionCodec.compressedOutputStream(bos) val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } try { - executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + executor.execute(new CheckpointWriteHandler( + checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => @@ -222,13 +236,24 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) - def fs = checkpointPath.getFileSystem(hadoopConf) + + // TODO(rxin): Why is this a def?! + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -239,18 +264,24 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() + var ois: ObjectInputStreamWithLoader = null + var cp: Checkpoint = null + Utils.tryWithSafeFinally { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + cp = ois.readObject.asInstanceOf[Checkpoint] + } { + if (ois != null) { + ois.close() + } + } cp.validate() logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) @@ -262,7 +293,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 0e285d6088ec..175140481e5a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray } - def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray } - def getReceiverInputStreams() = this.synchronized { + def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized { inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]]) .map(_.asInstanceOf[ReceiverInputDStream[_]]) .toArray diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index a0d8fb5ab93e..3249bb348981 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -55,7 +55,6 @@ case class Duration (private val millis: Long) { def div(that: Duration): Double = this / that - def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -71,7 +70,7 @@ case class Duration (private val millis: Long) { def milliseconds: Long = millis - def prettyPrint = Utils.msDurationToString(millis) + def prettyPrint: String = Utils.msDurationToString(millis) } @@ -80,7 +79,7 @@ case class Duration (private val millis: Long) { * a given number of milliseconds. */ object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) + def apply(milliseconds: Long): Duration = new Duration(milliseconds) } /** @@ -88,7 +87,7 @@ object Milliseconds { * a given number of seconds. */ object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) + def apply(seconds: Long): Duration = new Duration(seconds * 1000) } /** @@ -96,7 +95,7 @@ object Seconds { * a given number of minutes. */ object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) + def apply(minutes: Long): Duration = new Duration(minutes * 60000) } // Java-friendlier versions of the objects above. @@ -107,16 +106,16 @@ object Durations { /** * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. */ - def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. */ - def seconds(seconds: Long) = Seconds(seconds) + def seconds(seconds: Long): Duration = Seconds(seconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. */ - def minutes(minutes: Long) = Minutes(minutes) + def minutes(minutes: Long): Duration = Minutes(minutes) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index ad4f3fdd14ad..3f5be785e1b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) { this.endTime < that.endTime } - def <= (that: Interval) = (this < that || this == that) + def <= (that: Interval): Boolean = (this < that || this == that) - def > (that: Interval) = !(this <= that) + def > (that: Interval): Boolean = !(this <= that) - def >= (that: Interval) = !(this < that) + def >= (that: Interval): Boolean = !(this < that) - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString: String = "[" + beginTime + ", " + endTime + "]" } private[streaming] object Interval { - def currentInterval(duration: Duration): Interval = { + def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) val intervalBegin = time.floor(duration) new Interval(intervalBegin, intervalBegin + duration) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 8ef078713784..90c8b47aebce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -27,10 +27,12 @@ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ @@ -105,6 +107,19 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -113,10 +128,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { - new SparkContext(cp_.sparkConf) - } else { + if (sc_ != null) { sc_ + } else if (isCheckpointPresent) { + new SparkContext(cp_.createSparkConf()) + } else { + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -127,7 +144,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -172,7 +189,9 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) + assert(env != null) + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -186,7 +205,7 @@ class StreamingContext private[streaming] ( /** * Return the associated Spark context */ - def sparkContext = sc + def sparkContext: SparkContext = sc /** * Set each DStreams in this context to remember RDDs it generated in the last given duration. @@ -359,6 +378,30 @@ class StreamingContext private[streaming] ( new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly) } + /** + * Create a input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[ + K: ClassTag, + V: ClassTag, + F <: NewInputFormat[K, V]: ClassTag + ] (directory: String, + filter: Path => Boolean, + newFilesOnly: Boolean, + conf: Configuration): InputDStream[(K, V)] = { + new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly, Option(conf)) + } + /** * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value @@ -371,6 +414,37 @@ class StreamingContext private[streaming] ( fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files, assuming a fixed length per record, + * generating one byte array per record. Files must be written to the monitored directory + * by "moving" them from another location within the same file system. File names + * starting with . are ignored. + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new file + * @param recordLength length of each record in bytes + */ + @Experimental + def binaryRecordsStream( + directory: String, + recordLength: Int): DStream[Array[Byte]] = { + val conf = sc_.hadoopConfiguration + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( + directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + val data = br.map { case (k, v) => + val bytes = v.getBytes + assert(bytes.length == recordLength, "Byte array does not have correct length") + bytes + } + data + } + /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. @@ -469,10 +543,23 @@ class StreamingContext private[streaming] ( * will be thrown in this thread. * @param timeout time to wait in milliseconds */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { waiter.waitForStopOrError(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + waiter.waitForStopOrError(timeout) + } + /** * Stop the execution of the streams immediately (does not wait for all received data * to be processed). @@ -508,6 +595,7 @@ class StreamingContext private[streaming] ( // Even if we have already stopped, we still need to attempt to stop the SparkContext because // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). if (stopSparkContext) sc.stop() + uiTab.foreach(_.detach()) // The state should always be Stopped after calling `stop()`, even if we haven't started yet: state = Stopped } @@ -525,7 +613,8 @@ object StreamingContext extends Logging { @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) + : PairDStreamFunctions[K, V] = { DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } @@ -549,19 +638,59 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext + ): StreamingContext = { + getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext, + createOnError: Boolean + ): StreamingContext = { + val checkpointOption = CheckpointReader.read( + checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) + checkpointOption.map(new StreamingContext(sparkContext, _, null)) + .getOrElse(creatingFunc(sparkContext)) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 505e4431e435..01cdcb057404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.dstream.DStream * [[org.apache.spark.streaming.api.java.JavaPairDStream]]. */ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T]) - extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { + extends AbstractJavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index c382a12f4d09..808dcc174cf9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -34,6 +34,15 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ import org.apache.spark.streaming.dstream.DStream +/** + * As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations + * of JavaDStreamLike should extend this dummy abstract class instead of directly inheriting + * from the trait. See SPARK-3266 for additional details. + */ +private[streaming] +abstract class AbstractJavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], + R <: JavaRDDLike[T, R]] extends JavaDStreamLike[T, This, R] + trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] extends Serializable { implicit val classTag: ClassTag[T] @@ -160,7 +169,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -170,7 +179,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -181,7 +190,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -192,7 +203,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -406,8 +419,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T implicit val cmv2: ClassTag[V2] = fakeClassTag implicit val cmw: ClassTag[W] = fakeClassTag - def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = { transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd + } dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index de124cf40eff..93baad19e3ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -45,7 +45,7 @@ import org.apache.spark.streaming.dstream.DStream class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifest: ClassTag[K], implicit val vManifest: ClassTag[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + extends AbstractJavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) @@ -526,7 +526,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) @@ -726,7 +726,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsHadoopFiles(prefix: String, suffix: String) { dstream.saveAsHadoopFiles(prefix, suffix) } @@ -734,12 +734,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -747,12 +747,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsHadoopFiles( + def saveAsHadoopFiles[F <: OutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], + outputFormatClass: Class[F], conf: JobConf) { dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } @@ -761,7 +761,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) { + def saveAsNewAPIHadoopFiles(prefix: String, suffix: String) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix) } @@ -769,12 +769,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { + outputFormatClass: Class[F]) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass) } @@ -782,12 +782,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsNewAPIHadoopFiles( + def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]]( prefix: String, suffix: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], + outputFormatClass: Class[F], conf: Configuration = new Configuration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 9a2254bcdc1f..572d7d8e8753 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -29,15 +29,17 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -177,7 +179,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create an input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given + * a TCP socket and the receive bytes it interpreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data @@ -191,7 +193,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -209,6 +211,24 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.textFileStream(directory) } + /** + * :: Experimental :: + * + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them as flat binary files with fixed record lengths, + * yielding byte arrays + * + * '''Note:''' We ensure that the byte array for each record in the + * resulting RDDs of the DStream has the provided record length. + * + * @param directory HDFS directory to monitor for new files + * @param recordLength The length at which to split the records + */ + @Experimental + def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { + ssc.binaryRecordsStream(directory, recordLength) + } + /** * Create an input stream from network source hostname:port, where data is received * as serialized blocks (serialized using the Spark's serializer) that can be directly @@ -294,10 +314,41 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } + /** + * Create an input stream that monitors a Hadoop-compatible filesystem + * for new files and reads them using the given key-value types and input format. + * Files must be written to the monitored directory by "moving" them from another + * location within the same file system. File names starting with . are ignored. + * @param directory HDFS directory to monitor for new file + * @param kClass class of key for reading HDFS file + * @param vClass class of value for reading HDFS file + * @param fClass class of input format for reading HDFS file + * @param filter Function to filter paths to process + * @param newFilesOnly Should process only new files and ignore existing files in the directory + * @param conf Hadoop configuration + * @tparam K Key type for reading HDFS file + * @tparam V Value type for reading HDFS file + * @tparam F Input format for reading HDFS file + */ + def fileStream[K, V, F <: NewInputFormat[K, V]]( + directory: String, + kClass: Class[K], + vClass: Class[V], + fClass: Class[F], + filter: JFunction[Path, JBoolean], + newFilesOnly: Boolean, + conf: Configuration): JavaPairInputDStream[K, V] = { + implicit val cmk: ClassTag[K] = ClassTag(kClass) + implicit val cmv: ClassTag[V] = ClassTag(vClass) + implicit val cmf: ClassTag[F] = ClassTag(fClass) + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() + ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor @@ -547,10 +598,23 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * will be thrown in this thread. * @param timeout time to wait in milliseconds */ + @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { ssc.awaitTermination(timeout) } + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * + * @param timeout time to wait in milliseconds + * @return `true` if it's stopped; or throw the reported error during the execution; or `false` + * if the waiting time elapsed before returning from the method. + */ + def awaitTerminationOrTimeout(timeout: Long): Boolean = { + ssc.awaitTerminationOrTimeout(timeout) + } + /** * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ @@ -562,7 +626,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. @@ -570,7 +634,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } @@ -592,6 +656,7 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -613,6 +678,7 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -637,6 +703,7 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -649,6 +716,117 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 7053f47ec69a..4c28654ef641 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -176,11 +176,11 @@ private[python] abstract class PythonDStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent, parent2) + override def dependencies: List[DStream[_]] = List(parent, parent2) override def slideDuration: Duration = parent.slideDuration @@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream( extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) - override val mustCheckpoint = true - val invReduceFunc = new TransformFunction(pinvReduceFunc) + override val mustCheckpoint: Boolean = true + + val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b874f561c12e..24f99a2b929f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context = ssc + def context: StreamingContext = ssc /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() @@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] ( * operator, so this DStream will be registered as an output stream and there materialized. */ def print(num: Int) { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val firstNum = rdd.take(num + 1) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") - println() + def foreachFunc: (RDD[T], Time) => Unit = { + (rdd: RDD[T], time: Time) => { + val firstNum = rdd.take(num + 1) + println("-------------------------------------------") + println("Time: " + time) + println("-------------------------------------------") + firstNum.take(num).foreach(println) + if (firstNum.size > num) println("...") + println() + } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } @@ -837,7 +839,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 0dc72790fbdb..39fd21342813 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } } - override def toString() = { + override def toString: String = { "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index e7c5639a6349..eca69f00188e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -18,14 +18,15 @@ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.util.{TimeStampedHashMap, Utils} @@ -62,19 +63,32 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * the streaming app. * - If a file is to be visible in the directory listings, it must be visible within a certain * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be * selected as the mod time will be less than the ignore threshold when it becomes visible. * - Once a file is visible, the mod time cannot change. If it does due to appends, then the * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag]( +class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, - newFilesOnly: Boolean = true) + newFilesOnly: Boolean = true, + conf: Option[Configuration] = None) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { + private val serializableConfOpt = conf.map(new SerializableWritable(_)) + + /** + * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. + * + * Files with mod times older than this "window" of remembering will be ignored. So if new + * files are visible within this window, then the file will get selected in the next batch. + */ + private val minRememberDurationS = + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -83,14 +97,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Initial ignore threshold based on which old, existing files in the directory (at the time of // starting the streaming application) will be ignored or considered - private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.currentTime() else 0L + private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.getTimeMillis() else 0L /* * Make sure that the information of files selected in the last few batches are remembered. * This would allow us to filter away not-too-old files which have already been recently * selected and processed. */ - private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val numBatchesToRemember = FileInputDStream + .calculateNumBatchesToRemember(slideDuration, minRememberDurationS) private val durationToRemember = slideDuration * numBatchesToRemember remember(durationToRemember) @@ -156,7 +171,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): Array[String] = { try { - lastNewFileFindingTime = clock.currentTime() + lastNewFileFindingTime = clock.getTimeMillis() // Calculate ignore threshold val modTimeIgnoreThreshold = math.max( @@ -169,7 +184,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) } val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = clock.currentTime() - lastNewFileFindingTime + val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime logInfo("Finding new files took " + timeTaken + " ms") logDebug("# cached file times = " + fileToModTime.size) if (timeTaken > slideDuration.milliseconds) { @@ -237,7 +252,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { val fileRDDs = files.map(file =>{ - val rdd = context.sparkContext.newAPIHadoopFile[K, V, F](file) + val rdd = serializableConfOpt.map(_.value) match { + case Some(config) => context.sparkContext.newAPIHadoopFile( + file, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + config) + case None => context.sparkContext.newAPIHadoopFile[K, V, F](file) + } if (rdd.partitions.size == 0) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + @@ -285,7 +308,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] + private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() @@ -307,7 +330,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } } - override def toString() = { + override def toString: String = { "[\n" + hadoopFiles.size + " file sets\n" + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } @@ -317,20 +340,14 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas private[streaming] object FileInputDStream { - /** - * Minimum duration of remembering the information of selected files. Files with mod times - * older than this "window" of remembering will be ignored. So if new files are visible - * within this window, then the file will get selected in the next batch. - */ - private val MIN_REMEMBER_DURATION = Minutes(1) - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") /** * Calculate the number of last batches to remember, such that all the files selected in - * at least last MIN_REMEMBER_DURATION duration can be remembered. + * at least last minRememberDurationS duration can be remembered. */ - def calculateNumBatchesToRemember(batchDuration: Duration): Int = { - math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt + def calculateNumBatchesToRemember(batchDuration: Duration, + minRememberDurationS: Duration): Int = { + math.ceil(minRememberDurationS.milliseconds.toDouble / batchDuration.milliseconds).toInt } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index c81534ae584e..fcd5216f101a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag]( filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 658623455498..9d09a3baf37c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( flatMapValueFunc: V => TraversableOnce[U] ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index c7bb2833eabb..475ea2d2d4f3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag]( flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 1361c30395b5..685a32e1d280 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] ( foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index a9bb51f05404..dbb295fe54f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -25,7 +25,7 @@ private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index aa1993f0580a..e652702e213e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) } } - override def dependencies = List() + override def dependencies: List[DStream[_]] = List() override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 3d8ee29df1e8..5994bc1e23f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag]( preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 7aea1f945d9d..954d2eb4a7b0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( mapValueFunc: V => U ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index 02704a8d1c2e..fa14b2e897c3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] ( mapFunc: T => U ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index c0a5af0b65cc..1385ccbf56ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Reduce each batch of data using reduceByKey which will be further reduced by window // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + private val reducedStream = parent.reduceByKey(reduceFunc, partitioner) // Persist RDDs to memory by default as these RDDs are going to be reused. super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(reducedStream) + override def dependencies: List[DStream[_]] = List(reducedStream) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 880a89bc3689..7757ccac09a5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( mapSideCombine: Boolean = true ) extends DStream[(K,C)] (parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebb04dd35b9a..de8718d0a80f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 71b61856e23c..5d46ca0715ff 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] ( require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index abbc40befa95..9405dbaa1232 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { + parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) + } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 775b6bfd065c..899865a906c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index dd1e96334952..93caa4ba35c7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - def segmentLocations = HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) - blockLocations.getOrElse(segmentLocations) + blockLocations.getOrElse( + HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index a7d63bd4f2db..cd309788a771 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.receiver +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ @@ -25,10 +26,10 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StorageLevel -import java.nio.ByteBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag]( class Supervisor extends Actor { override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) + private val worker = context.actorOf(props, name) logInfo("Started receiver worker at:" + worker.path) - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) + private val n: AtomicInteger = new AtomicInteger(0) + private val hiccups: AtomicInteger = new AtomicInteger(0) - def receive = { + override def receive: PartialFunction[Any, Unit] = { case IteratorData(iterator) => logDebug("received iterator") @@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag]( } } - def onStart() = { + def onStart(): Unit = { supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - } - def onStop() = { + def onStop(): Unit = { supervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 79263a718397..f4963a78e1d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{SystemClock, Utils} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -78,9 +79,9 @@ private[streaming] class BlockGenerator( private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) private val clock = new SystemClock() - private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200) + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") private val blockIntervalTimer = - new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator") + new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize) private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @@ -119,7 +120,7 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. All received data items * will be periodically pushed into BlockManager. */ - def addDataWithCallback(data: Any, metadata: Any) = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { waitToPush() currentBuffer += data listener.onAddData(data, metadata) @@ -131,7 +132,7 @@ private[streaming] class BlockGenerator( val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[Any] if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockInterval) + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) val newBlock = new Block(blockId, newBlockBuffer) listener.onGenerateBlock(blockId) blocksForPushing.put(newBlock) // put is blocking when queue is full diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index e4f6ba626ebb..97db9ded8336 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.receiver import org.apache.spark.{Logging, SparkConf} -import java.util.concurrent.TimeUnit._ +import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * @@ -33,37 +33,12 @@ import java.util.concurrent.TimeUnit._ */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private var lastSyncTime = System.nanoTime - private var messagesWrittenSinceSync = 0L private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) def waitToPush() { - if( desiredRate <= 0 ) { - return - } - val now = System.nanoTime - val elapsedNanosecs = math.max(now - lastSyncTime, 1) - val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs - if (rate < desiredRate) { - // It's okay to write; just update some variables and return - messagesWrittenSinceSync += 1 - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - messagesWrittenSinceSync = 1 - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. - val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate - val elapsedTimeInMillis = elapsedNanosecs / 1000000 - val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis - if (sleepTimeInMillis > 0) { - logTrace("Natural rate is " + rate + " per second but desired rate is " + - desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") - Thread.sleep(sleepTimeInMillis) - } - waitToPush() + if (desiredRate > 0) { + rateLimiter.acquire() } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index f7a8ebee8a54..297bf04c0c25 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ -import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -150,7 +150,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in parallel implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + ThreadUtils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) /** * This implementation stores the block into the block manager as well as a write ahead log. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5acf8a9a811e..5b5a3fe64860 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * Get the unique identifier the receiver input stream that this * receiver is associated with. */ - def streamId = id + def streamId: Int = id /* * ================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 1f0244c251eb..4943f29395d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor( } /** Check if receiver has been marked for stopping */ - def isReceiverStarted() = { + def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started } /** Check if receiver has been marked for stopping */ - def isReceiverStopped() = { + def isReceiverStopped(): Boolean = { logDebug("state = " + receiverState) receiverState == Stopped } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 716cf2c7f32f..89af40330b9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -21,18 +21,16 @@ import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import akka.actor.{Actor, Props} -import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -63,33 +61,23 @@ private[streaming] class ReceiverSupervisorImpl( } - /** Remote Akka actor for the ReceiverTracker */ - private val trackerActor = { - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://%s@%s:%s/user/ReceiverTracker".format( - SparkEnv.driverActorSystemName, ip, port) - env.actorSystem.actorSelection(url) - } - - /** Timeout for Akka actor messages */ - private val askTimeout = AkkaUtils.askTimeout(env.conf) + /** Remote RpcEndpointRef for the ReceiverTracker */ + private val trackerEndpoint = RpcUtils.makeDriverRef("ReceiverTracker", env.conf, env.rpcEnv) - /** Akka actor for receiving messages from the ReceiverTracker in the driver */ - private val actor = env.actorSystem.actorOf( - Props(new Actor { + /** RpcEndpointRef for receiving messages from the ReceiverTracker in the driver */ + private val endpoint = env.rpcEnv.setupEndpoint( + "Receiver-" + streamId + "-" + System.currentTimeMillis(), new ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = env.rpcEnv - override def receive() = { + override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") - stop("Stopped by driver", None) + ReceiverSupervisorImpl.this.stop("Stopped by driver", None) case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) } - - def ref = self - }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) + }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) @@ -158,15 +146,14 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ def reportError(message: String, error: Throwable) { val errorString = Option(error).map(Throwables.getStackTraceAsString).getOrElse("") - trackerActor ! ReportError(streamId, message, errorString) + trackerEndpoint.send(ReportError(streamId, message, errorString)) logWarning("Reported error " + message + " - " + error) } @@ -176,22 +163,19 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onStop(message: String, error: Option[Throwable]) { blockGenerator.stop() - env.actorSystem.stop(actor) + env.rpcEnv.stop(endpoint) } override protected def onReceiverStart() { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), actor) - val future = trackerActor.ask(msg)(askTimeout) - Await.result(future, askTimeout) + streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + trackerEndpoint.askWithReply[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - val future = trackerActor.ask( - DeregisterReceiver(streamId, message, errorString))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7e0f6b2cdfc0..30cf87f5b7dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) { id = "streaming job " + time + "." + number } - override def toString = id + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8632c94349bf..2467d50839ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,17 +19,17 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import akka.actor.{ActorRef, Props, Actor} - import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} -import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer} +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent -private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class DoCheckpoint( + time: Time, clearCheckpointDataLater: Boolean) extends JobGeneratorEvent private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** @@ -45,12 +45,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clock = { val clockClass = ssc.sc.conf.get( - "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + "spark.streaming.clock", "org.apache.spark.util.SystemClock") + try { + Class.forName(clockClass).newInstance().asInstanceOf[Clock] + } catch { + case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => + val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") + Class.forName(newClockClass).newInstance().asInstanceOf[Clock] + } } private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator") // This is marked lazy so that this is initialized after checkpoint duration has been set // in the context and the generator has been started. @@ -62,22 +68,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { null } - // eventActor is created when generator starts. + // eventLoop is created when generator starts. // This not being null means the scheduler has been started and not stopped - private var eventActor: ActorRef = null + private var eventLoop: EventLoop[JobGeneratorEvent] = null // last batch whose completion,checkpointing and metadata cleanup has been completed private var lastProcessedBatch: Time = null /** Start generation of jobs */ def start(): Unit = synchronized { - if (eventActor != null) return // generator has already been started + if (eventLoop != null) return // generator has already been started + + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { + override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobGeneratorEvent => processEvent(event) + override protected def onError(e: Throwable): Unit = { + jobScheduler.reportError("Error in job generator", e) } - }), "JobGenerator") + } + eventLoop.start() + if (ssc.isCheckpointPresent) { restart() } else { @@ -91,22 +101,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * checkpoints written. */ def stop(processReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // generator has already been stopped + if (eventLoop == null) return // generator has already been stopped if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") val timeWhenStopStarted = System.currentTimeMillis() - val stopTimeout = conf.getLong( - "spark.streaming.gracefulStopTimeout", - 10 * ssc.graph.batchDuration.milliseconds - ) + val stopTimeoutMs = conf.getTimeAsMs( + "spark.streaming.gracefulStopTimeout", s"${10 * ssc.graph.batchDuration.milliseconds}ms") val pollTime = 100 // To prevent graceful stop to get stuck permanently - def hasTimedOut = { - val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + def hasTimedOut: Boolean = { + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeoutMs if (timedOut) { - logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") + logWarning("Timed out while stopping the job generator (timeout = " + stopTimeoutMs + ")") } timedOut } @@ -125,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Stopped generation timer") // Wait for the jobs to complete and checkpoints to be written - def haveAllBatchesBeenProcessed = { + def haveAllBatchesBeenProcessed: Boolean = { lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime } logInfo("Waiting for jobs to be processed and checkpoints to be written") @@ -140,9 +148,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.stop() } - // Stop the actor and checkpoint writer + // Stop the event loop and checkpoint writer if (shouldCheckpoint) checkpointWriter.stop() - ssc.env.actorSystem.stop(eventActor) + eventLoop.stop() logInfo("Stopped JobGenerator") } @@ -150,14 +158,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { - eventActor ! ClearMetadata(time) + eventLoop.post(ClearMetadata(time)) } /** * Callback called when the checkpoint of a batch has been written. */ - def onCheckpointCompletion(time: Time) { - eventActor ! ClearCheckpointData(time) + def onCheckpointCompletion(time: Time, clearCheckpointDataLater: Boolean) { + if (clearCheckpointDataLater) { + eventLoop.post(ClearCheckpointData(time)) + } } /** Processes all events */ @@ -166,7 +176,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) - case DoCheckpoint(time) => doCheckpoint(time) + case DoCheckpoint(time, clearCheckpointDataLater) => + doCheckpoint(time, clearCheckpointDataLater) case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -238,7 +249,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } - eventActor ! DoCheckpoint(time) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } /** Clear DStream metadata for the given `time`. */ @@ -248,7 +259,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { - eventActor ! DoCheckpoint(time) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = true)) } else { // If checkpointing is not enabled, then delete metadata information about // received blocks (block data not saved in any case). Otherwise, wait for @@ -271,11 +282,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } /** Perform checkpoint for the give `time`. */ - private def doCheckpoint(time: Time) { + private def doCheckpoint(time: Time, clearCheckpointDataLater: Boolean) { if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) - checkpointWriter.write(new Checkpoint(ssc, time)) + checkpointWriter.write(new Checkpoint(ssc, time), clearCheckpointDataLater) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0e0f5bd3b9db..508b89278dcb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,13 +17,15 @@ package org.apache.spark.streaming.scheduler -import scala.util.{Failure, Success, Try} -import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import akka.actor.{ActorRef, Actor, Props} -import org.apache.spark.{SparkException, Logging, SparkEnv} + +import scala.collection.JavaConversions._ +import scala.util.{Failure, Success} + +import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ +import org.apache.spark.util.EventLoop private[scheduler] sealed trait JobSchedulerEvent @@ -46,22 +48,22 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { val listenerBus = new StreamingListenerBus() // These two are created only when scheduler starts. - // eventActor not being null means the scheduler has been started and not stopped + // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null - private var eventActor: ActorRef = null - + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { - if (eventActor != null) return // scheduler has already been started + if (eventLoop != null) return // scheduler has already been started logDebug("Starting JobScheduler") - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobSchedulerEvent => processEvent(event) - } - }), "JobScheduler") + eventLoop = new EventLoop[JobSchedulerEvent]("JobScheduler") { + override protected def onReceive(event: JobSchedulerEvent): Unit = processEvent(event) + + override protected def onError(e: Throwable): Unit = reportError("Error in job scheduler", e) + } + eventLoop.start() - listenerBus.start() + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) receiverTracker.start() jobGenerator.start() @@ -69,11 +71,11 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def stop(processAllReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // scheduler has already been stopped + if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") // First, stop receiving - receiverTracker.stop() + receiverTracker.stop(processAllReceivedData) // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. @@ -96,8 +98,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // Stop everything else listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - eventActor = null + eventLoop.stop() + eventLoop = null logInfo("Stopped JobScheduler") } @@ -105,6 +107,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (jobSet.jobs.isEmpty) { logInfo("No jobs added for time " + jobSet.time) } else { + listenerBus.post(StreamingListenerBatchSubmitted(jobSet.toBatchInfo)) jobSets.put(jobSet.time, jobSet) jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + jobSet.time) @@ -116,7 +119,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def reportError(msg: String, e: Throwable) { - eventActor ! ErrorReported(msg, e) + eventLoop.post(ErrorReported(msg, e)) } private def processEvent(event: JobSchedulerEvent) { @@ -134,10 +137,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) - if (!jobSet.hasStarted) { + val isFirstJobOfJobSet = !jobSet.hasStarted + jobSet.handleJobStart(job) + if (isFirstJobOfJobSet) { + // "StreamingListenerBatchStarted" should be posted after calling "handleJobStart" to get the + // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } @@ -168,14 +174,14 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private class JobHandler(job: Job) extends Runnable { def run() { - eventActor ! JobStarted(job) + eventLoop.post(JobStarted(job)) // Disable checks for existing output directories in jobs launched by the streaming scheduler, // since we may need to write output to an existing directory during checkpoint recovery; // see SPARK-4835 for more details. PairRDDFunctions.disableOutputSpecValidation.withValue(true) { job.run() } - eventActor ! JobCompleted(job) + eventLoop.post(JobCompleted(job)) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 8c15a75b1b0e..5b134877d0b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,8 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty - ) { + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -48,17 +47,17 @@ case class JobSet( if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted = processingStartTime > 0 + def hasStarted: Boolean = processingStartTime > 0 - def hasCompleted = incompleteJobs.isEmpty + def hasCompleted: Boolean = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) - def processingDelay = processingEndTime - processingStartTime + def processingDelay: Long = processingEndTime - processingStartTime // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay = { + def totalDelay: Long = { processingEndTime - time.milliseconds } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index e19ac939f9ac..200cf4ef4b0f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager} -import org.apache.spark.util.Utils +import org.apache.spark.streaming.util.WriteAheadLogManager +import org.apache.spark.util.{Clock, Utils} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -150,7 +150,7 @@ private[streaming] class ReceivedBlockTracker( * returns only after the files are cleaned up. */ def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { - assert(cleanupThreshTime.milliseconds < clock.currentTime()) + assert(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index d7e39c528c51..52f08b9c9de6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import akka.actor.ActorRef import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef /** * :: DeveloperApi :: @@ -28,7 +28,7 @@ import org.apache.spark.annotation.DeveloperApi case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val actor: ActorRef, + private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 4f998869731e..c4ead6f30a63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,13 +17,11 @@ package org.apache.spark.streaming.scheduler - import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials -import akka.actor._ - import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} @@ -36,7 +34,7 @@ private[streaming] case class RegisterReceiver( streamId: Int, typ: String, host: String, - receiverActor: ActorRef + receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends ReceiverTrackerMessage @@ -67,33 +65,33 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus - // actor is created when generator starts. + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null - /** Start the actor and receiver execution thread. */ - def start() = synchronized { - if (actor != null) { + /** Start the endpoint and receiver execution thread. */ + def start(): Unit = synchronized { + if (endpoint != null) { throw new SparkException("ReceiverTracker already started") } if (!receiverInputStreams.isEmpty) { - actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), - "ReceiverTracker") + endpoint = ssc.env.rpcEnv.setupEndpoint( + "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } } /** Stop the receiver execution thread. */ - def stop() = synchronized { - if (!receiverInputStreams.isEmpty && actor != null) { + def stop(graceful: Boolean): Unit = synchronized { + if (!receiverInputStreams.isEmpty && endpoint != null) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop() + if (!skipReceiverLaunch) receiverExecutor.stop(graceful) - // Finally, stop the actor - ssc.env.actorSystem.stop(actor) - actor = null + // Finally, stop the endpoint + ssc.env.rpcEnv.stop(endpoint) + endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } @@ -129,8 +127,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.actor) } - .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) } + receiverInfo.values.flatMap { info => Option(info.endpoint) } + .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } } } @@ -139,23 +137,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false streamId: Int, typ: String, host: String, - receiverActor: ActorRef, - sender: ActorRef + receiverEndpoint: RpcEndpointRef, + senderAddress: RpcAddress ) { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverActor, true, host) + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) } /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) + oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, lastError = error) case None => logWarning("No prior receiver info") ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) @@ -199,25 +197,30 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.hasUnallocatedReceivedBlocks } - /** Actor to receive messages from the receivers. */ - private class ReceiverTrackerActor extends Actor { - def receive = { - case RegisterReceiver(streamId, typ, host, receiverActor) => - registerReceiver(streamId, typ, host, receiverActor, sender) - sender ! true - case AddBlock(receivedBlockInfo) => - sender ! addBlock(receivedBlockInfo) + /** RpcEndpoint to receive messages from the receivers. */ + private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + + override def receive: PartialFunction[Any, Unit] = { case ReportError(streamId, message, error) => reportError(streamId, message, error) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(true) + case AddBlock(receivedBlockInfo) => + context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) - sender ! true + context.reply(true) } } /** This thread class runs all the receivers on the cluster. */ class ReceiverLauncher { @transient val env = ssc.env + @volatile @transient private var running = false @transient val thread = new Thread() { override def run() { try { @@ -233,7 +236,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop() { + def stop(graceful: Boolean) { // Send the stop signal to all the receivers stopReceivers() @@ -241,9 +244,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // That is, for the receivers to quit gracefully. thread.join(10000) + if (graceful) { + val pollTime = 100 + logInfo("Waiting for receiver job to terminate gracefully") + while (receiverInfo.nonEmpty || running) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + // Check if all the receivers have been deregistered or not - if (!receiverInfo.isEmpty) { - logWarning("All of the receivers have not deregistered, " + receiverInfo) + if (receiverInfo.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { logInfo("All of the receivers have deregistered successfully") } @@ -295,15 +307,17 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") + running = true ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + running = false logInfo("All of the receivers have been terminated") } /** Stops the receivers. */ private def stopReceivers() { // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.actor)} - .foreach { _ ! StopReceiver } + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index ed1aa114e19d..74dbba453f02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -50,9 +50,6 @@ case class StreamingListenerReceiverError(receiverInfo: ReceiverInfo) case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo) extends StreamingListenerEvent -/** An event used in the listener to shutdown the listener daemon thread. */ -private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent - /** * :: DeveloperApi :: * A listener interface for receiving information about an ongoing streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 398724d9e813..b07d6cf347ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -17,83 +17,42 @@ package org.apache.spark.streaming.scheduler +import java.util.concurrent.atomic.AtomicBoolean + import org.apache.spark.Logging -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import java.util.concurrent.LinkedBlockingQueue +import org.apache.spark.util.AsynchronousListenerBus /** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ -private[spark] class StreamingListenerBus() extends Logging { - private val listeners = new ArrayBuffer[StreamingListener]() - with SynchronizedBuffer[StreamingListener] - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - - val listenerThread = new Thread("StreamingListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - event match { - case receiverStarted: StreamingListenerReceiverStarted => - listeners.foreach(_.onReceiverStarted(receiverStarted)) - case receiverError: StreamingListenerReceiverError => - listeners.foreach(_.onReceiverError(receiverError)) - case receiverStopped: StreamingListenerReceiverStopped => - listeners.foreach(_.onReceiverStopped(receiverStopped)) - case batchSubmitted: StreamingListenerBatchSubmitted => - listeners.foreach(_.onBatchSubmitted(batchSubmitted)) - case batchStarted: StreamingListenerBatchStarted => - listeners.foreach(_.onBatchStarted(batchStarted)) - case batchCompleted: StreamingListenerBatchCompleted => - listeners.foreach(_.onBatchCompleted(batchCompleted)) - case StreamingListenerShutdown => - // Get out of the while loop and shutdown the daemon thread - return - case _ => - } - } +private[spark] class StreamingListenerBus + extends AsynchronousListenerBus[StreamingListener, StreamingListenerEvent]("StreamingListenerBus") + with Logging { + + private val logDroppedEvent = new AtomicBoolean(false) + + override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = { + event match { + case receiverStarted: StreamingListenerReceiverStarted => + listener.onReceiverStarted(receiverStarted) + case receiverError: StreamingListenerReceiverError => + listener.onReceiverError(receiverError) + case receiverStopped: StreamingListenerReceiverStopped => + listener.onReceiverStopped(receiverStopped) + case batchSubmitted: StreamingListenerBatchSubmitted => + listener.onBatchSubmitted(batchSubmitted) + case batchStarted: StreamingListenerBatchStarted => + listener.onBatchStarted(batchStarted) + case batchCompleted: StreamingListenerBatchCompleted => + listener.onBatchCompleted(batchCompleted) + case _ => } } - def start() { - listenerThread.start() - } - - def addListener(listener: StreamingListener) { - listeners += listener - } - - def post(event: StreamingListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { + override def onDropEvent(event: StreamingListenerEvent): Unit = { + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + "rate at which events are being started by the scheduler.") - queueFullErrorMessageLogged = true } } - - /** - * Waits until there are no more events in the queue, or until the specified time has elapsed. - * Used for testing only. Returns true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify - * add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - def stop(): Unit = post(StreamingListenerShutdown) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala new file mode 100644 index 000000000000..df1c0a10704c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import scala.xml.Node + +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.ui.UIUtils + +private[ui] abstract class BatchTableBase(tableId: String) { + + protected def columns: Seq[Node] = { +
    + + + + } + + protected def baseRow(batch: BatchInfo): Seq[Node] = { + val batchTime = batch.batchTime.milliseconds + val formattedBatchTime = UIUtils.formatDate(batch.batchTime.milliseconds) + val eventCount = batch.receivedBlockInfo.values.map { + receivers => receivers.map(_.numRecords).sum + }.sum + val schedulingDelay = batch.schedulingDelay + val formattedSchedulingDelay = schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") + val processingTime = batch.processingDelay + val formattedProcessingTime = processingTime.map(UIUtils.formatDuration).getOrElse("-") + + + + + + } + + private def batchTable: Seq[Node] = { +
    Batch TimeInput SizeScheduling DelayProcessing Time{formattedBatchTime}{eventCount.toString} events + {formattedSchedulingDelay} + + {formattedProcessingTime} +
    + + {columns} + + + {renderRows} + +
    + } + + def toNodeSeq: Seq[Node] = { + batchTable + } + + /** + * Return HTML for all rows of this table. + */ + protected def renderRows: Seq[Node] +} + +private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatches: Seq[BatchInfo]) + extends BatchTableBase("active-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Status + + override protected def renderRows: Seq[Node] = { + // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display + // waiting batches before running batches + waitingBatches.flatMap(batch => {waitingBatchRow(batch)}) ++ + runningBatches.flatMap(batch => {runningBatchRow(batch)}) + } + + private def runningBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ processing + } + + private def waitingBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ queued + } +} + +private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) + extends BatchTableBase("completed-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Total Delay + + override protected def renderRows: Seq[Node] = { + batches.flatMap(batch => {completedBatchRow(batch)}) + } + + private def completedBatchRow(batch: BatchInfo): Seq[Node] = { + val totalDelay = batch.totalDelay + val formattedTotalDelay = totalDelay.map(UIUtils.formatDuration).getOrElse("-") + baseRow(batch) ++ + + {formattedTotalDelay} + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 5ee53a5c5f56..be1e8686cf9f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable.{Queue, HashMap} + import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ -import scala.collection.mutable.{Queue, HashMap} import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted import org.apache.spark.streaming.scheduler.BatchInfo @@ -32,7 +33,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private val waitingBatchInfos = new HashMap[Time, BatchInfo] private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedaBatchInfos = new Queue[BatchInfo] + private val completedBatchInfos = new Queue[BatchInfo] private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L @@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + synchronized { + waitingBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + } } - override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) @@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() - totalCompletedBatches += 1L - - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + synchronized { + waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) + runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) + completedBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedBatchInfos.size > batchInfoLimit) completedBatchInfos.dequeue() + totalCompletedBatches += 1L + + batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalProcessedRecords += infos.map(_.numRecords).sum + } } } - def numReceivers = synchronized { + def numReceivers: Int = synchronized { ssc.graph.getReceiverInputStreams().size } @@ -113,7 +118,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedaBatchInfos.toSeq + completedBatchInfos.toSeq } def processingDelayDistribution: Option[Distribution] = synchronized { @@ -144,7 +149,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) }.toMap } - def lastReceivedBatchRecords: Map[Int, Long] = { + def lastReceivedBatchRecords: Map[Int, Long] = synchronized { val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => (0 until numReceivers).map { receiverId => @@ -155,24 +160,24 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = synchronized { receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = { - completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchInfo] = synchronized { + completedBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = { + def lastReceivedBatch: Option[BatchInfo] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = synchronized { + private def retainedBatches: Seq[BatchInfo] = { (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + runningBatchInfos.values.toSeq ++ completedBatchInfos).sortBy(_.batchTime)(Time.ordering) } private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + Distribution(completedBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 98e9a2e639e2..07fa285642ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -32,25 +32,28 @@ private[ui] class StreamingPage(parent: StreamingTab) extends WebUIPage("") with Logging { private val listener = parent.listener - private val startTime = Calendar.getInstance().getTime() + private val startTime = System.currentTimeMillis() private val emptyCell = "-" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val content = + val content = listener.synchronized { generateBasicStats() ++

    ++

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ - generateBatchStatsTable() + generateBatchStatsTable() ++ + generateBatchListTables() + } UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ private def generateBasicStats(): Seq[Node] = { - val timeSinceStart = System.currentTimeMillis() - startTime.getTime + val timeSinceStart = System.currentTimeMillis() - startTime + // scalastyle:off
    • - Started at: {startTime.toString} + Started at: {UIUtils.formatDate(startTime)}
    • Time since start: {formatDurationVerbose(timeSinceStart)} @@ -62,18 +65,19 @@ private[ui] class StreamingPage(parent: StreamingTab) Batch interval: {formatDurationVerbose(listener.batchDuration)}
    • - Processed batches: {listener.numTotalCompletedBatches} + Completed batches: {listener.numTotalCompletedBatches}
    • - Waiting batches: {listener.numUnprocessedBatches} + Active batches: {listener.numUnprocessedBatches}
    • - Received records: {listener.numTotalReceivedRecords} + Received events: {listener.numTotalReceivedRecords}
    • - Processed records: {listener.numTotalProcessedRecords} + Processed events: {listener.numTotalProcessedRecords}
    + // scalastyle:on } /** Generate stats of data received by the receivers in the streaming program */ @@ -85,10 +89,10 @@ private[ui] class StreamingPage(parent: StreamingTab) "Receiver", "Status", "Location", - "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", - "Minimum rate\n[records/sec]", - "Median rate\n[records/sec]", - "Maximum rate\n[records/sec]", + "Events in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", + "Minimum rate\n[events/sec]", + "Median rate\n[events/sec]", + "Maximum rate\n[events/sec]", "Last Error" ) val dataRows = (0 until listener.numReceivers).map { receiverId => @@ -189,5 +193,26 @@ private[ui] class StreamingPage(parent: StreamingTab) } UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) } + + private def generateBatchListTables(): Seq[Node] = { + val runningBatches = listener.runningBatches.sortBy(_.batchTime.milliseconds).reverse + val waitingBatches = listener.waitingBatches.sortBy(_.batchTime.milliseconds).reverse + val completedBatches = listener.retainedCompletedBatches. + sortBy(_.batchTime.milliseconds).reverse + + val activeBatchesContent = { +

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ + new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq + } + + val completedBatchesContent = { +

    + Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) +

    ++ + new CompletedBatchTable(completedBatches).toNodeSeq + } + + activeBatchesContent ++ completedBatchesContent + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index d9d04cd706a0..9a860ea4a6c6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -36,6 +36,10 @@ private[spark] class StreamingTab(ssc: StreamingContext) ssc.addStreamingListener(listener) attachPage(new StreamingPage(this)) parent.attachTab(this) + + def detach() { + getSparkUI(ssc).detachTab(this) + } } private object StreamingTab { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala deleted file mode 100644 index d6d96d7ba00f..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.util - -private[streaming] -trait Clock { - def currentTime(): Long - def waitTillTime(targetTime: Long): Long -} - -private[streaming] -class SystemClock() extends Clock { - - val minPollTime = 25L - - def currentTime(): Long = { - System.currentTimeMillis() - } - - def waitTillTime(targetTime: Long): Long = { - var currentTime = 0L - currentTime = System.currentTimeMillis() - - var waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - - val pollTime = math.max(waitTime / 10.0, minPollTime).toLong - - while (true) { - currentTime = System.currentTimeMillis() - waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime - } - val sleepTime = math.min(waitTime, pollTime) - Thread.sleep(sleepTime) - } - -1 - } -} - -private[streaming] -class ManualClock() extends Clock { - - private var time = 0L - - def currentTime() = this.synchronized { - time - } - - def setTime(timeToSet: Long) = { - this.synchronized { - time = timeToSet - this.notifyAll() - } - } - - def addToTime(timeToAdd: Long) = { - this.synchronized { - time += timeToAdd - this.notifyAll() - } - } - def waitTillTime(targetTime: Long): Long = { - this.synchronized { - while (time < targetTime) { - this.wait(100) - } - } - currentTime() - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index a73d6f3bf066..4d968f8bfa7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.OpenHashMap -import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { @@ -71,7 +69,7 @@ object RawTextHelper { var count = 0 while(data.hasNext) { - value = data.next + value = data.next() if (value != null) { count += 1 if (len == 0) { @@ -108,9 +106,13 @@ object RawTextHelper { } } - def add(v1: Long, v2: Long) = (v1 + v2) + def add(v1: Long, v2: Long): Long = { + v1 + v2 + } - def subtract(v1: Long, v2: Long) = (v1 - v2) + def subtract(v1: Long, v2: Long): Long = { + v1 - v2 + } - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long): Long = math.max(v1, v2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index a7850812bd61..ca2f319f174a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -72,7 +72,8 @@ object RawTextSender extends Logging { } catch { case e: IOException => logError("Client disconnected") - socket.close() + } finally { + socket.close() } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 1a616a0434f2..c8eef833eb43 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.Logging +import org.apache.spark.util.{Clock, SystemClock} private[streaming] class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) @@ -38,7 +39,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * current system time. */ def getStartTime(): Long = { - (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + (math.floor(clock.getTimeMillis().toDouble / period) + 1).toLong * period } /** @@ -48,7 +49,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: * more than current time. */ def getRestartTime(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime + val gap = clock.getTimeMillis() - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala index 166661b7496d..38a93cc3c9a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} +import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} import WriteAheadLogManager._ /** @@ -60,7 +60,7 @@ private[streaming] class WriteAheadLogManager( if (callerName.nonEmpty) s" for $callerName" else "" private val threadpoolName = s"WriteAheadLogManager $callerNameTag" implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(1, threadpoolName)) + ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None @@ -82,7 +82,7 @@ private[streaming] class WriteAheadLogManager( var succeeded = false while (!succeeded && failures < maxFailures) { try { - fileSegment = getLogWriter(clock.currentTime).write(byteBuffer) + fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer) succeeded = true } catch { case ex: Exception => diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index d4c40745658c..cb2e8380b493 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,10 +22,12 @@ import java.nio.charset.Charset; import java.util.*; +import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -316,6 +319,7 @@ public void testReduceByWindowWithoutInverse() { testReduceByWindow(false); } + @SuppressWarnings("unchecked") private void testReduceByWindow(boolean withInverse) { List> inputData = Arrays.asList( Arrays.asList(1,2,3), @@ -684,6 +688,7 @@ public void testStreamingContextTransform(){ JavaDStream transformed1 = ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { + @Override public JavaRDD call(List> listOfRDDs, Time time) { Assert.assertEquals(2, listOfRDDs.size()); return null; @@ -697,6 +702,7 @@ public JavaRDD call(List> listOfRDDs, Time time) { JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { + @Override public JavaPairRDD> call(List> listOfRDDs, Time time) { Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); @@ -926,7 +932,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -984,12 +990,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1120,7 +1126,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1141,14 +1147,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1246,17 +1252,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1289,17 +1295,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1325,7 +1331,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1704,6 +1710,74 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final MutableBoolean newContextCreated = new MutableBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.setValue(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + + // Function to create JavaStreamingContext using existing JavaSparkContext + // without any output operations (used to detect the new context) + Function creatingFunc2 = + new Function() { + public JavaStreamingContext call(JavaSparkContext context) { + newContextCreated.setValue(true); + return new JavaStreamingContext(context, Seconds.apply(1)); + } + }; + + JavaSparkContext sc = new JavaSparkContext(conf); + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") @@ -1769,7 +1843,7 @@ public Iterable call(InputStream in) throws IOException { @SuppressWarnings("unchecked") @Test public void testTextFileStream() throws IOException { - File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); List> expected = fileTestPrepare(testDir); JavaDStream input = ssc.textFileStream(testDir.toString()); @@ -1782,7 +1856,7 @@ public void testTextFileStream() throws IOException { @SuppressWarnings("unchecked") @Test public void testFileStream() throws IOException { - File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir")); + File testDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); List> expected = fileTestPrepare(testDir); JavaPairInputDStream inputStream = ssc.fileStream( @@ -1828,4 +1902,23 @@ private List> fileTestPrepare(File testDir) throws IOException { return expected; } + + @SuppressWarnings("unchecked") + // SPARK-5795: no logic assertions, just testing that intended API invocations compile + private void compileSaveAsJavaAPI(JavaPairDStream pds) { + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + // Checks that a previous common workaround for this API still compiles + pds.saveAsNewAPIHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + pds.saveAsHadoopFiles( + "", "", LongWritable.class, Text.class, + (Class) org.apache.hadoop.mapred.SequenceFileOutputFormat.class); + } + } diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 1e24da7f5f60..cfedb5a042a3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -31,7 +31,7 @@ public void setUp() { SparkConf conf = new SparkConf() .setMaster("local[2]") .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a..75e3b53a093f 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index e8f4a7779ec2..87bc20f79c3c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -22,13 +22,12 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag -import util.ManualClock - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} +import org.apache.spark.util.{Clock, ManualClock} import org.apache.spark.HashPartitioner class BasicOperationsSuite extends TestSuiteBase { @@ -172,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -475,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } @@ -586,7 +587,7 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) collectRddInfo() } @@ -637,8 +638,8 @@ class BasicOperationsSuite extends TestSuiteBase { ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] if (rememberDuration != null) ssc.remember(rememberDuration) val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.currentTime() === Seconds(10).milliseconds) + val clock = ssc.scheduler.clock.asInstanceOf[Clock] + assert(clock.getTimeMillis() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) operatedStream } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8f8bc61437ba..6b0a3f91d4d0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -32,8 +32,7 @@ import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutput import org.scalatest.concurrent.Eventually._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * This test suites tests the checkpointing functionality of DStreams - @@ -44,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -61,7 +60,7 @@ class CheckpointSuite extends TestSuiteBase { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") val stateStreamCheckpointInterval = Seconds(1) val fs = FileSystem.getLocal(new Configuration()) @@ -73,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -147,7 +146,7 @@ class CheckpointSuite extends TestSuiteBase { // This tests whether spark conf persists through checkpoints, and certain // configs gets scrubbed - test("persistence of conf through checkpoints") { + test("recovery of conf through checkpoints") { val key = "spark.mykey" val value = "myvalue" System.setProperty(key, value) @@ -155,7 +154,7 @@ class CheckpointSuite extends TestSuiteBase { val originalConf = ssc.conf val cp = new Checkpoint(ssc, Time(1000)) - val cpConf = cp.sparkConf + val cpConf = cp.createSparkConf() assert(cpConf.get("spark.master") === originalConf.get("spark.master")) assert(cpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(cpConf.get(key) === value) @@ -164,7 +163,8 @@ class CheckpointSuite extends TestSuiteBase { // Serialize/deserialize to simulate write to storage and reading it back val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - val newCpConf = newCp.sparkConf + // Verify new SparkConf has all the previous properties + val newCpConf = newCp.createSparkConf() assert(newCpConf.get("spark.master") === originalConf.get("spark.master")) assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) assert(newCpConf.get(key) === value) @@ -175,6 +175,20 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(null, newCp, null) val restoredConf = ssc.conf assert(restoredConf.get(key) === value) + ssc.stop() + + // Verify new SparkConf picks up new master url if it is set in the properties. See SPARK-6331. + try { + val newMaster = "local[100]" + System.setProperty("spark.master", newMaster) + val newCpConf = newCp.createSparkConf() + assert(newCpConf.get("spark.master") === newMaster) + assert(newCpConf.get("spark.app.name") === originalConf.get("spark.app.name")) + ssc = new StreamingContext(null, newCp, null) + assert(ssc.sparkContext.master === newMaster) + } finally { + System.clearProperty("spark.master") + } } @@ -185,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -198,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -208,7 +228,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -222,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -231,7 +257,7 @@ class CheckpointSuite extends TestSuiteBase { } test("recovery with saveAsNewAPIHadoopFiles operation") { - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -245,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -269,7 +301,7 @@ class CheckpointSuite extends TestSuiteBase { // // After SPARK-5079 is addressed, should be able to remove this test since a strengthened // version of the other saveAsHadoopFile* tests would prevent regressions for this issue. - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { testCheckpointedOperation( Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), @@ -284,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -324,13 +362,13 @@ class CheckpointSuite extends TestSuiteBase { * Writes a file named `i` (which contains the number `i`) to the test directory and sets its * modification time to `clock`'s current time. */ - def writeFile(i: Int, clock: ManualClock): Unit = { + def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charsets.UTF_8) - assert(file.setLastModified(clock.currentTime())) + assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: - assert(file.lastModified() === clock.currentTime()) + assert(file.lastModified() === clock.getTimeMillis()) } /** @@ -372,13 +410,13 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Advance half a batch so that the first file is created after the StreamingContext starts - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Create files and advance manual clock to process them for (i <- Seq(1, 2, 3)) { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) if (i != 3) { // Since we want to shut down while the 3rd batch is processing eventually(eventuallyTimeout) { @@ -386,15 +424,14 @@ class CheckpointSuite extends TestSuiteBase { } } } - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { // Wait until all files have been recorded and all batches have started assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed @@ -410,7 +447,7 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } // Recover context from checkpoint file and verify whether the files that were @@ -419,7 +456,7 @@ class CheckpointSuite extends TestSuiteBase { withStreamingContext(new StreamingContext(checkpointDir)) { ssc => // So that the restarted StreamingContext's clock has gone forward in time since failure ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) - val oldClockTime = clock.currentTime() + val oldClockTime = clock.getTimeMillis() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val batchCounter = new BatchCounter(ssc) val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] @@ -430,7 +467,7 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Verify that the clock has traveled forward to the expected time eventually(eventuallyTimeout) { - clock.currentTime() === oldClockTime + clock.getTimeMillis() === oldClockTime } // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) val numBatchesAfterRestart = 4 @@ -441,12 +478,12 @@ class CheckpointSuite extends TestSuiteBase { writeFile(i, clock) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -519,17 +556,18 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -538,4 +576,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 40434b1f9b70..0c4c06534a69 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -20,37 +20,30 @@ package org.apache.spark.streaming import org.apache.spark.Logging import org.apache.spark.util.Utils -import java.io.File - /** * This testsuite tests master failures at random times while the stream is running using * the real clock. */ class FailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) - - override def useManualClock = false + override def batchDuration: Duration = Milliseconds(1000) - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } + override def useManualClock: Boolean = false override def afterFunction() { + Utils.deleteRecursively(directory) super.afterFunction() - Utils.deleteRecursively(new File(directory)) } test("multiple failures with map") { - MasterFailureTest.testMap(directory, numBatches, batchDuration) + MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { - MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) + MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index bddf51e13042..e6ac4975c5e6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,12 +17,8 @@ package org.apache.spark.streaming -import akka.actor.Actor -import akka.actor.Props -import akka.util.ByteString - import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{InetSocketAddress, SocketException, ServerSocket} +import java.net.{SocketException, ServerSocket} import java.nio.charset.Charset import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger @@ -36,9 +32,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.util.Utils -import org.apache.spark.streaming.receiver.{ActorHelper, Receiver} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.rdd.RDD import org.apache.hadoop.io.{Text, LongWritable} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat @@ -57,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -69,7 +64,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping server") @@ -95,6 +90,57 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("binary records stream") { + val testDir: File = null + try { + val batchDuration = Seconds(2) + val testDir = Utils.createTempDir() + // Create a file that exists before the StreamingContext is created: + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, Charset.forName("UTF-8")) + assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) + + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // This `setTime` call ensures that the clock is past the creation time of `existingFile` + clock.setTime(existingFile.lastModified + batchDuration.milliseconds) + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.binaryRecordsStream(testDir.toString, 1) + val outputBuffer = new ArrayBuffer[Seq[Array[Byte]]] + with SynchronizedBuffer[Seq[Array[Byte]]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + outputStream.register() + ssc.start() + + // Advance the clock so that the files are created after StreamingContext starts, but + // not enough to trigger a batch + clock.advance(batchDuration.milliseconds / 2) + + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + Thread.sleep(batchDuration.milliseconds) + val file = new File(testDir, i.toString) + Files.write(Array[Byte](i.toByte), file) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) + logInfo("Created file " + file) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.advance(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === i) + } + } + + val expectedOutput = input.map(i => i.toByte) + val obtainedOutput = outputBuffer.flatten.toList.map(i => i(0).toByte) + assert(obtainedOutput === expectedOutput) + } + } finally { + if (testDir != null) Utils.deleteRecursively(testDir) + } + } test("file input stream - newFilesOnly = true") { testFileStream(newFilesOnly = true) @@ -118,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -128,7 +174,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && System.currentTimeMillis() - startTime < 5000) { Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -150,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -158,12 +204,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) } Thread.sleep(1000) logInfo("Stopping context") @@ -193,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -205,12 +251,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Enqueue the first 3 items (one by one), they should be merged in the next batch val inputIterator = input.toIterator inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) // Enqueue the remaining items (again one by one), merged in the final batch inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(1000) logInfo("Stopping context") ssc.stop() @@ -257,19 +303,19 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Advance the clock so that the files are created after StreamingContext starts, but // not enough to trigger a batch - clock.addToTime(batchDuration.milliseconds / 2) + clock.advance(batchDuration.milliseconds / 2) // Over time, create files in the directory val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) Files.write(i + "\n", file, Charset.forName("UTF-8")) - assert(file.setLastModified(clock.currentTime())) - assert(file.lastModified === clock.currentTime) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) // Advance the clock after creating the file to avoid a race when // setting its modification time - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) eventually(eventuallyTimeout) { assert(batchCounter.getNumCompletedBatches === i) } @@ -306,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -338,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 132ff2443fc0..c090eaec2928 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,22 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -56,27 +54,24 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") - tempDirectory = Files.createTempDir() + tempDirectory = Utils.createTempDir() manualClock.setTime(0) } @@ -89,14 +84,11 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null - if (tempDirectory != null && tempDirectory.exists()) { - FileUtils.deleteDirectory(tempDirectory) - tempDirectory = null - } + Utils.deleteRecursively(tempDirectory) } test("BlockManagerBasedBlockHandler - store blocks") { @@ -104,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -128,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -165,7 +157,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche preCleanupLogFiles.size should be > 1 // this depends on the number of blocks inserted using generateAndStoreData() - manualClock.currentTime() shouldEqual 5000L + manualClock.getTimeMillis() shouldEqual 5000L val cleanupThreshTime = 3000L handler.cleanupOldBlocks(cleanupThreshTime) @@ -243,7 +235,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val blockIds = Seq.fill(blocks.size)(generateBlockId()) val storeResults = blocks.zip(blockIds).map { case (block, id) => - manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + manualClock.advance(500) // log rolling interval set to 1000 ms through SparkConf logDebug("Inserting block " + id) receivedBlockHandler.storeBlock(id, block) }.toList diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index fbb7b0bfebaf..b63b37d9f9ce 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -24,8 +24,6 @@ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} import scala.util.Random -import com.google.common.io.Files -import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ @@ -34,9 +32,9 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader} +import org.apache.spark.streaming.util.WriteAheadLogReader import org.apache.spark.streaming.util.WriteAheadLogSuite._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -51,15 +49,12 @@ class ReceivedBlockTrackerSuite before { conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite") - checkpointDirectory = Files.createTempDir() + checkpointDirectory = Utils.createTempDir() } after { allReceivedBlockTrackers.foreach { _.stop() } - if (checkpointDirectory != null && checkpointDirectory.exists()) { - FileUtils.deleteDirectory(checkpointDirectory) - checkpointDirectory = null - } + Utils.deleteRecursively(checkpointDirectory) } test("block addition, and block to batch allocation") { @@ -100,7 +95,7 @@ class ReceivedBlockTrackerSuite def incrementTime() { val timeIncrementMillis = 2000L - manualClock.addToTime(timeIncrementMillis) + manualClock.advance(timeIncrementMillis) } // Generate and add blocks to the given tracker @@ -138,13 +133,13 @@ class ReceivedBlockTrackerSuite tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 // Allocate blocks to batch and verify whether the unallocated blocks got allocated - val batchTime1 = manualClock.currentTime + val batchTime1 = manualClock.getTimeMillis() tracker2.allocateBlocksToBatch(batchTime1) tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1 // Add more blocks and allocate to another batch incrementTime() - val batchTime2 = manualClock.currentTime + val batchTime2 = manualClock.getTimeMillis() val blockInfos2 = addBlockInfos(tracker2) tracker2.allocateBlocksToBatch(batchTime2) tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2 @@ -233,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. */ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -249,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e8c34a9ee40b..b84129fd70dd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,7 +24,6 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -34,6 +33,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { @@ -131,11 +131,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString) + val blockIntervalMs = 200 + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 5 - val waitTime = expectedBlocks * blockInterval + (blockInterval / 2) + val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -155,17 +155,17 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(recordedData.toSet === generatedData.toSet) } - test("block generator throttling") { + ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 100 - val maxRate = 100 - val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). + val blockIntervalMs = 100 + val maxRate = 1001 + val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) val expectedBlocks = 20 - val waitTime = expectedBlocks * blockInterval + val waitTime = expectedBlocks * blockIntervalMs val expectedMessages = maxRate * waitTime / 1000 - val expectedMessagesPerBlock = maxRate * blockInterval / 1000 + val expectedMessagesPerBlock = maxRate * blockIntervalMs / 1000 val generatedData = new ArrayBuffer[Int] // Generate blocks @@ -176,7 +176,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { blockGenerator.addData(count) generatedData += count count += 1 - Thread.sleep(1) } blockGenerator.stop() @@ -185,25 +184,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received") assert(recordedData.toSet === generatedData.toSet, "Received data not same") - // recordedData size should be close to the expected rate - val minExpectedMessages = expectedMessages - 3 - val maxExpectedMessages = expectedMessages + 1 + // recordedData size should be close to the expected rate; use an error margin proportional to + // the value, so that rate changes don't cause a brittle test + val minExpectedMessages = expectedMessages - 0.05 * expectedMessages + val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages val numMessages = recordedData.size assert( numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages, s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages" ) - val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 - val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 + // XXX Checking every block would require an even distribution of messages across blocks, + // which throttling code does not control. Therefore, test against the average. + val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock + val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") + + // the first and last block may be incomplete, so we slice them out + val validBlocks = recordedBlocks.drop(1).dropRight(1) + val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size + assert( - // the first and last block may be incomplete, so we slice them out - recordedBlocks.drop(1).dropRight(1).forall { block => - block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock - }, + averageBlockSize >= minExpectedMessagesPerBlock && + averageBlockSize <= maxExpectedMessagesPerBlock, s"# records in received blocks = [$receivedBlockSizes], not between " + - s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock" + s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average" ) } @@ -222,7 +227,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { .set("spark.streaming.receiver.writeAheadLog.enable", "true") .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") val batchDuration = Milliseconds(500) - val tempDirectory = Files.createTempDir() + val tempDirectory = Utils.createTempDir() val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1)) val allLogFiles1 = new mutable.HashSet[String]() @@ -251,7 +256,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc => - tempDirectory.deleteOnExit() val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true)) val receiverStream1 = ssc.receiverStream(receiver1) @@ -309,7 +313,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -321,24 +325,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 9f352bdcb089..4f193322ad33 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.io.File import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -73,9 +75,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from existing SparkContext") { @@ -85,24 +87,26 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10") + myConf.set("spark.cleaner.ttl", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) - assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") + assert( + Utils.timeStringAsSeconds(cp.sparkConfPairs + .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.sparkConf.getInt("spark.cleaner.ttl", -1) === 10) + assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === 10) + assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } test("start and stop state check") { @@ -176,7 +180,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600") + conf.set("spark.cleaner.ttl", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") @@ -190,7 +194,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTermination(500) + ssc.awaitTerminationOrTimeout(500) ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) @@ -205,6 +209,32 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("stop slow receiver gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.streaming.gracefulStopTimeout", "20000s") + sc = new SparkContext(conf) + logInfo("==================================\n\n\n") + ssc = new StreamingContext(sc, Milliseconds(100)) + var runningCount = 0 + SlowTestReceiver.receivedAllRecords = false + // Create test receiver that sleeps in onStop() + val totalNumRecords = 15 + val recordsPerSecond = 1 + val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) + input.count().foreachRDD { rdd => + val count = rdd.first() + runningCount += count.toInt + logInfo("Count = " + count + ", Running count = " + runningCount) + } + ssc.start() + ssc.awaitTerminationOrTimeout(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + assert(runningCount > 0) + assert(runningCount == totalNumRecords) + Thread.sleep(100) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -217,7 +247,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w // test whether awaitTermination() exits after give amount of time failAfter(1000 millis) { - ssc.awaitTermination(500) + ssc.awaitTerminationOrTimeout(500) } // test whether awaitTermination() does not exit if not time is given @@ -262,7 +292,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val exception = intercept[Exception] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("map task"), "Expected exception not thrown") } @@ -273,11 +303,168 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register() val exception = intercept[TestException] { ssc.start() - ssc.awaitTermination(5000) + ssc.awaitTerminationOrTimeout(5000) } assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("awaitTerminationOrTimeout") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register() + + ssc.start() + + // test whether awaitTerminationOrTimeout() return false after give amount of time + failAfter(1000 millis) { + assert(ssc.awaitTerminationOrTimeout(500) === false) + } + + // test whether awaitTerminationOrTimeout() return true if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + new Thread() { + override def run() { + Thread.sleep(500) + ssc.stop() + } + }.start() + assert(ssc.awaitTerminationOrTimeout(10000) === true) + } + } + + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("getOrCreate with existing SparkContext") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(sparkContext: SparkContext): StreamingContext = { + newContextCreated = true + new StreamingContext(sparkContext, batchDuration) + } + + // Call ssc.stop(stopSparkContext = false) after a body of cody + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val checkpointPath = createValidCheckpoint() + + // StreamingContext.getOrCreate should recover context with checkpoint path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + assert(!ssc.conf.contains("someKey"), + "recovered StreamingContext unexpectedly has old config") + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -287,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("someKey", "someValue") + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) @@ -319,6 +530,39 @@ object TestReceiver { val counter = new AtomicInteger(1) } +/** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + logInfo("Receiving started") + for(i <- 1 to totalRecords) { + Thread.sleep(1000 / recordsPerSecond) + store(i) + } + SlowTestReceiver.receivedAllRecords = true + logInfo(s"Received all $totalRecords records") + } + } + receivingThreadOption = Some(thread) + thread.start() + } + + def onStop() { + // Simulate slow receiver by waiting for all records to be produced + while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + // no cleanup to be done, the receiving thread should stop on it own + } +} + +object SlowTestReceiver { + var receivedAllRecords = false +} + /** Streaming application for testing DStream and RDD creation sites */ package object testPackage extends Assertions { def test() { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f7..721043950954 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,18 +38,46 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) - val batchInfos = collector.batchInfos - batchInfos should have size 4 - batchInfos.foreach(info => { + // SPARK-6766: batch info should be submitted + val batchInfosSubmitted = collector.batchInfosSubmitted + batchInfosSubmitted should have size 4 + + batchInfosSubmitted.foreach(info => { + info.schedulingDelay should be (None) + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + + // SPARK-6766: processingStartTime of batch info should not be None when starting + val batchInfosStarted = collector.batchInfosStarted + batchInfosStarted should have size 4 + + batchInfosStarted.foreach(info => { + info.schedulingDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + + // test onBatchCompleted + val batchInfosCompleted = collector.batchInfosCompleted + batchInfosCompleted should have size 4 + + batchInfosCompleted.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -58,9 +86,9 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -99,9 +127,20 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfos = new ArrayBuffer[BatchInfo] + val batchInfosCompleted = new ArrayBuffer[BatchInfo] + val batchInfosStarted = new ArrayBuffer[BatchInfo] + val batchInfosSubmitted = new ArrayBuffer[BatchInfo] + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { + batchInfosSubmitted += batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { + batchInfosStarted += batchStarted.batchInfo + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfos += batchCompleted.batchInfo + batchInfosCompleted += batchCompleted.batchInfo } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 7d82c3e4aadc..c3cae8aeb6d1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -31,10 +31,9 @@ import org.scalatest.concurrent.PatienceConfiguration import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -54,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -105,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -149,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -189,10 +191,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def beforeFunction() { if (useManualClock) { logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") } else { logInfo("Using real clock") - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") + conf.set("spark.streaming.clock", "org.apache.spark.util.SystemClock") } } @@ -333,23 +335,24 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.currentTime()) + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) if (actuallyWait) { for (i <- 1 to numBatches) { logInfo("Actually waiting for " + batchDuration) - clock.addToTime(batchDuration.milliseconds) + clock.advance(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } } else { - clock.addToTime(numBatches * batchDuration.milliseconds) + clock.advance(numBatches * batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.currentTime()) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - ssc.awaitTermination(50) + ssc.awaitTerminationOrTimeout(50) } val timeTaken = System.currentTimeMillis() - startTime logInfo("Output generated in " + timeTaken + " milliseconds") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala new file mode 100644 index 000000000000..205ddf6dbe9b --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ + + + + +/** + * Selenium tests for the Spark Web UI. + */ +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + + implicit var webDriver: WebDriver = _ + + override def beforeAll(): Unit = { + webDriver = new HtmlUnitDriver + } + + override def afterAll(): Unit = { + if (webDriver != null) { + webDriver.quit() + } + } + + /** + * Create a test SparkStreamingContext with the SparkUI enabled. + */ + private def newSparkStreamingContext(): StreamingContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val ssc = new StreamingContext(conf, Seconds(1)) + assert(ssc.sc.ui.isDefined, "Spark UI is not started!") + ssc + } + + test("attaching and detaching a Streaming tab") { + withStreamingContext(newSparkStreamingContext()) { ssc => + val sparkUI = ssc.sparkContext.ui.get + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + // check whether streaming page exists + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq + statisticText should contain("Network receivers:") + statisticText should contain("Batch interval:") + + val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq + h4Text should contain("Active Batches (0)") + h4Text should contain("Completed Batches (last 0 out of 0)") + + findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Status") + } + findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Total Delay") + } + } + + ssc.stop(false) + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/")) + find(cssSelector( """ul li a[href*="streaming"]""")) should be(None) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming") + val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq + statisticText should not contain ("Network receivers:") + statisticText should not contain ("Batch interval:") + } + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala deleted file mode 100644 index 8e3011826685..000000000000 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -import scala.io.Source - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.SparkConf - -class UISuite extends FunSuite { - - // Ignored: See SPARK-1530 - ignore("streaming tab in spark UI") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - val ssc = new StreamingContext(conf, Seconds(1)) - assert(ssc.sc.ui.isDefined, "Spark UI is not started!") - val ui = ssc.sc.ui.get - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress).mkString - assert(!html.contains("random data that should not be present")) - // test if streaming tab exist - assert(html.toLowerCase.contains("streaming")) - // test if other Spark tabs still exist - assert(html.toLowerCase.contains("stages")) - } - - eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString - assert(html.toLowerCase.contains("batch")) - assert(html.toLowerCase.contains("network")) - } - } -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16..c39ad05f4152 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577d..c3602a5b7373 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala new file mode 100644 index 000000000000..7865b06c2e3c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming._ +import org.apache.spark.util.{ManualClock, Utils} + +class JobGeneratorSuite extends TestSuiteBase { + + // SPARK-6222 is a tricky regression bug which causes received block metadata + // to be deleted before the corresponding batch has completed. This occurs when + // the following conditions are met. + // 1. streaming checkpointing is enabled by setting streamingContext.checkpoint(dir) + // 2. input data is received through a receiver as blocks + // 3. a batch processing a set of blocks takes a long time, such that a few subsequent + // batches have been generated and submitted for processing. + // + // The JobGenerator (as of Mar 16, 2015) checkpoints twice per batch, once after generation + // of a batch, and another time after the completion of a batch. The cleanup of + // checkpoint data (including block metadata, etc.) from DStream must be done only after the + // 2nd checkpoint has completed, that is, after the batch has been completely processed. + // However, the issue is that the checkpoint data and along with it received block data is + // cleaned even in the case of the 1st checkpoint, causing pre-mature deletion of received block + // data. For example, if the 3rd batch is still being process, the 7th batch may get generated, + // and the corresponding "1st checkpoint" will delete received block metadata of batch older + // than 6th batch. That, is 3rd batch's block metadata gets deleted even before 3rd batch has + // been completely processed. + // + // This test tries to create that scenario by the following. + // 1. enable checkpointing + // 2. generate batches with received blocks + // 3. make the 3rd batch never complete + // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) + // 5. verify whether 3rd batch's block metadata still exists + // + test("SPARK-6222: Do not clear received block data too soon") { + import JobGeneratorSuite._ + val checkpointDir = Utils.createTempDir() + val testConf = conf + testConf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + testConf.set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + + withStreamingContext(new StreamingContext(testConf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val numBatches = 10 + val longBatchNumber = 3 // 3rd batch will take a long time + val longBatchTime = longBatchNumber * batchDuration.milliseconds + + val testTimeout = timeout(10 seconds) + val inputStream = ssc.receiverStream(new TestReceiver) + + inputStream.foreachRDD((rdd: RDD[Int], time: Time) => { + if (time.milliseconds == longBatchTime) { + while (waitLatch.getCount() > 0) { + waitLatch.await() + println("Await over") + } + } + }) + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.getAbsolutePath) + ssc.start() + + // Make sure the only 1 batch of information is to be remembered + assert(inputStream.rememberDuration === batchDuration) + val receiverTracker = ssc.scheduler.receiverTracker + + // Get the blocks belonging to a batch + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { + receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) + } + + // Wait for new blocks to be received + def waitForNewReceivedBlocks() { + eventually(testTimeout) { + assert(receiverTracker.hasUnallocatedBlocks) + } + } + + // Wait for received blocks to be allocated to a batch + def waitForBlocksToBeAllocatedToBatch(batchTime: Long) { + eventually(testTimeout) { + assert(getBlocksOfBatch(batchTime).nonEmpty) + } + } + + // Generate a large number of batches with blocks in them + for (batchNum <- 1 to numBatches) { + waitForNewReceivedBlocks() + clock.advance(batchDuration.milliseconds) + waitForBlocksToBeAllocatedToBatch(clock.getTimeMillis()) + } + + // Wait for 3rd batch to start + eventually(testTimeout) { + ssc.scheduler.getPendingTimes().contains(Time(numBatches * batchDuration.milliseconds)) + } + + // Verify that the 3rd batch's block data is still present while the 3rd batch is incomplete + assert(getBlocksOfBatch(longBatchTime).nonEmpty, "blocks of incomplete batch already deleted") + assert(batchCounter.getNumCompletedBatches < longBatchNumber) + waitLatch.countDown() + } + } +} + +object JobGeneratorSuite { + val waitLatch = new CountDownLatch(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala new file mode 100644 index 000000000000..94b1985116fe --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import org.scalatest.Matchers + +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} + +class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + override def batchDuration: Duration = Milliseconds(100) + + test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + + "onReceiverStarted, onReceiverError, onReceiverStopped") { + val ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + listener.waitingBatches should be (List(batchInfoSubmitted)) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (0) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (List(batchInfoStarted)) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (600) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (List(batchInfoCompleted)) + listener.lastCompletedBatch should be (Some(batchInfoCompleted)) + listener.numUnprocessedBatches should be (0) + listener.numTotalCompletedBatches should be (1) + listener.numTotalProcessedRecords should be (600) + listener.numTotalReceivedRecords should be (600) + + // onReceiverStarted + val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (None) + + // onReceiverError + val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (None) + + // onReceiverStopped + val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (Some(receiverInfoStopped)) + listener.receiverInfo(3) should be (None) + } + + test("Remove the old completed batches when exceeding the limit") { + val ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + + for(_ <- 0 until (limit + 10)) { + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.retainedCompletedBatches.size should be (limit) + listener.numTotalCompletedBatches should be(limit + 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7ce9499dc614..a3919c43b95b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -26,7 +26,7 @@ import scala.language.{implicitConversions, postfixOps} import WriteAheadLogSuite._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.Utils +import org.apache.spark.util.{ManualClock, Utils} import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ @@ -197,7 +197,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - manager.cleanupOldLogs(manualClock.currentTime() / 2, waitForCompletion) + manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion) if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) @@ -219,7 +219,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() - manualClock.addToTime(100000) + manualClock.advance(100000) writeDataUsingManager(testDir, dataToWrite2, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) @@ -279,19 +279,19 @@ object WriteAheadLogSuite { manualClock: ManualClock = new ManualClock, stopManager: Boolean = true ): WriteAheadLogManager = { - if (manualClock.currentTime < 100000) manualClock.setTime(10000) + if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val manager = new WriteAheadLogManager(logDirectory, hadoopConf, rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => - manualClock.addToTime(500) + manualClock.advance(500) manager.writeToLog(item) } if (stopManager) manager.stop() manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74..d66750463033 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/tools/pom.xml b/tools/pom.xml index e7419ed2c607..1c6f3e83a181 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d..583823c90c5c 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d..f2d135397ce2 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -58,8 +58,8 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { @@ -78,7 +78,7 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() diff --git a/yarn/pom.xml b/yarn/pom.xml index 7595549e4b6d..7c8c3613e7a0 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,8 +19,8 @@ 4.0.0 org.apache.spark - spark-parent - 1.3.0-SNAPSHOT + spark-parent_2.10 + 1.4.0-SNAPSHOT ../pom.xml @@ -58,6 +58,34 @@ org.apache.hadoop hadoop-client + + + + com.google.guava + guava + + + org.eclipse.jetty + jetty-server + + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-http + + + org.eclipse.jetty + jetty-servlet + + + org.apache.hadoop hadoop-yarn-server-tests diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 902bdda59860..93ae45133ce2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -19,39 +19,40 @@ package org.apache.spark.deploy.yarn import scala.util.control.NonFatal -import java.io.IOException +import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException -import java.net.Socket +import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster(args: ApplicationMasterArguments, - client: YarnRMClient) extends Logging { +private[spark] class ApplicationMaster( + args: ApplicationMasterArguments, + client: YarnRMClient) + extends Logging { + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val sparkConf = new SparkConf() private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf) .asInstanceOf[YarnConfiguration] - private val isDriver = args.userClass != null + private val isClusterMode = args.userClass != null // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", @@ -64,12 +65,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, @volatile private var finalMsg: String = "" @volatile private var userClassThread: Thread = _ - private var reporterThread: Thread = _ - private var allocator: YarnAllocator = _ + @volatile private var reporterThread: Thread = _ + @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -78,7 +79,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, try { val appAttemptId = client.getAttemptId() - if (isDriver) { + if (isClusterMode) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -93,50 +94,44 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } - val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - - if (!finished) { - // This happens when the user application calls System.exit(). We have the choice - // of either failing or succeeding at this point. We report success to avoid - // retrying applications that have succeeded (System.exit(0)), which means that - // applications that explicitly exit with a non-zero status will also show up as - // succeeded in the RM UI. - finish(finalStatus, - ApplicationMaster.EXIT_SUCCESS, - "Shutdown hook called before final status was reported.") - } - if (!unregistered) { - // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { - unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) - } + Utils.addShutdownHook { () => + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // This happens when the user application calls System.exit(). We have the choice + // of either failing or succeeding at this point. We report success to avoid + // retrying applications that have succeeded (System.exit(0)), which means that + // applications that explicitly exit with a non-zero status will also show up as + // succeeded in the RM UI. + finish(finalStatus, + ApplicationMaster.EXIT_SUCCESS, + "Shutdown hook called before final status was reported.") + } + + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir(fs) } } } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserClass which does a + // Hadoop UGI. This has to happen before the startUserApplication which does a // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { + if (isClusterMode) { runDriver(securityMgr) } else { runExecutorLauncher(securityMgr) @@ -147,7 +142,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e.getMessage()) + "Uncaught exception: " + e) } exitCode } @@ -158,8 +153,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { - if (isDriver) { + final def getDefaultFinalStatus(): FinalApplicationStatus = { + if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { FinalApplicationStatus.UNDEFINED @@ -171,31 +166,35 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. */ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = Utils.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } } } } @@ -217,6 +216,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val appId = client.getAttemptId().getApplicationId().toString() val historyAddress = sparkConf.getOption("spark.yarn.historyServer.address") + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } .getOrElse("") @@ -231,9 +231,27 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, reporterThread = launchReporterThread() } + /** + * Create an [[RpcEndpoint]] that communicates with the driver. + * + * In cluster mode, the AM and the driver belong to same process + * so the AMEndpoint need not monitor lifecycle of the driver. + */ + private def runAMEndpoint( + host: String, + port: String, + isClusterMode: Boolean): Unit = { + val driverEndpont = rpcEnv.setupEndpointRef( + SparkEnv.driverActorSystemName, + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) + } + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() - userClassThread = startUserClass() + userClassThread = startUserApplication() // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. @@ -245,15 +263,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { + rpcEnv = sc.env.rpcEnv + runAMEndpoint( + sc.getConf.get("spark.driver.host"), + sc.getConf.get("spark.driver.port"), + isClusterMode = true) registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 - actor = waitForSparkDriver() + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) + waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -267,7 +289,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = - sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") // must be <= expiryInterval / 2. val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) @@ -344,13 +366,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def waitForSparkContextInitialized(): SparkContext = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { - val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries") - .map(_.toLong * 10000L) - if (waitTries.isDefined) { - logWarning( - "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") - } - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", waitTries.getOrElse(100000L)) + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { @@ -367,7 +383,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } - private def waitForSparkDriver(): ActorRef = { + private def waitForSparkDriver(): Unit = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -375,8 +391,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // Spark driver should already be up since it launched us, but we don't want to // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", 100000L) - val deadline = System.currentTimeMillis + totalWaitTime + val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val deadline = System.currentTimeMillis + totalWaitTimeMs while (!driverUp && !finished && System.currentTimeMillis < deadline) { try { @@ -399,12 +415,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - driverHost, - driverPort.toString, - YarnSchedulerBackend.ACTOR_NAME) - actorSystem.actorOf(Props(new AMActor(driverUrl)), name = "YarnAM") + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -412,11 +423,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" val params = client.getAmIpFilterParams(yarnConf, proxyBase) - if (isDriver) { + if (isClusterMode) { System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -427,11 +438,30 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, * * Returns the user thread that was started. */ - private def startUserClass(): Thread = { - logInfo("Starting the user JAR in a separate Thread") + private def startUserApplication(): Thread = { + logInfo("Starting the user application in a separate Thread") System.setProperty("spark.executor.instances", args.numExecutors.toString) - val mainMethod = Class.forName(args.userClass, false, - Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) + + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + val userClassLoader = + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } + + if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { + System.setProperty("spark.submit.pyFiles", + PythonRunner.formatPaths(args.pyFiles).mkString(",")) + } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } + val mainMethod = userClassLoader.loadClass(args.userClass) + .getMethod("main", classOf[Array[String]]) val userThread = new Thread { override def run() { @@ -446,53 +476,45 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - case e: Exception => + case cause: Throwable => + logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + e.getMessage) - // re-throw to get it logged - throw e + "User class threw exception: " + cause) } } } } + userThread.setContextClassLoader(userClassLoader) userThread.setName("Driver") userThread.start() userThread } /** - * Actor that communicates with the driver in client deploy mode. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String) extends Actor { - var driver: ActorSelection = _ - - override def preStart() = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + } + override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => - logInfo(s"Driver requested a total number of $requestedTotal executor(s).") Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -500,7 +522,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } @@ -508,8 +539,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, object ApplicationMaster extends Logging { - val SHUTDOWN_HOOK_PRIORITY: Int = 30 - // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 @@ -521,7 +550,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -530,11 +559,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -546,7 +575,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index d76a63276d75..ae6dc1094d72 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -24,6 +24,9 @@ import collection.mutable.ArrayBuffer class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null + var primaryPyFile: String = null + var primaryRFile: String = null + var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 var executorCores = 1 @@ -48,6 +51,18 @@ class ApplicationMasterArguments(val args: Array[String]) { userClass = value args = tail + case ("--primary-py-file") :: value :: tail => + primaryPyFile = value + args = tail + + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + + case ("--py-files") :: value :: tail => + pyFiles = value + args = tail + case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -69,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -81,6 +101,10 @@ class ApplicationMasterArguments(val args: Array[String]) { |Options: | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class + | --primary-py-file A main Python file + | --primary-r-file A main R file + | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to + | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d4eeccf64275..019afbd1a174 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,22 +17,29 @@ package org.apache.spark.deploy.yarn +import java.io.{File, FileOutputStream} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} +import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects +import com.google.common.io.Files import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -40,6 +47,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} @@ -66,16 +74,12 @@ private[spark] class Client( private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) def stop(): Unit = yarnClient.stop() - /* ------------------------------------------------------------------------------------- * - | The following methods have much in common in the stable and alpha versions of Client, | - | but cannot be implemented in the parent trait due to subtle API differences across | - | hadoop versions. | - * ------------------------------------------------------------------------------------- */ - /** * Submit an application running our ApplicationMaster to the ResourceManager. * @@ -183,8 +187,7 @@ private[spark] class Client( private[yarn] def copyFileToRemote( destDir: Path, srcPath: Path, - replication: Short, - setPerms: Boolean = false): Path = { + replication: Short): Path = { val destFs = destDir.getFileSystem(hadoopConf) val srcFs = srcPath.getFileSystem(hadoopConf) var destPath = srcPath @@ -193,9 +196,7 @@ private[spark] class Client( logInfo(s"Uploading resource $srcPath -> $destPath") FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf) destFs.setReplication(destPath, replication) - if (setPerms) { - destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) - } + destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) } else { logInfo(s"Source and destination file systems are the same. Not copying $srcPath") } @@ -219,7 +220,12 @@ private[spark] class Client( val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst + // Used to keep track of URIs added to the distributed cache. If the same URI is added + // multiple times, YARN will fail to launch containers for the app with an internal + // error. + val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) + obtainTokenForHiveMetastore(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -236,29 +242,41 @@ private[spark] class Client( "for alternatives.") } + def addDistributedUri(uri: URI): Boolean = { + val uriStr = uri.toString() + if (distributedUris.contains(uriStr)) { + logWarning(s"Resource $uri added multiple times to distributed cache.") + false + } else { + distributedUris += uriStr + true + } + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. - * Each resource is represented by a 4-tuple of: + * Each resource is represented by a 3-tuple of: * (1) destination resource name, * (2) local path to the resource, - * (3) Spark property key to set if the scheme is not local, and - * (4) whether to set permissions for this resource + * (3) Spark property key to set if the scheme is not local */ List( - (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false), - (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true), - ("log4j.properties", oldLog4jConf.orNull, null, false) - ).foreach { case (destName, _localPath, confKey, setPermissions) => + (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), + (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), + ("log4j.properties", oldLog4jConf.orNull, null) + ).foreach { case (destName, _localPath, confKey) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (!localPath.isEmpty()) { val localURI = new URI(localPath) if (localURI.getScheme != LOCAL_SCHEME) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication, setPermissions) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) + if (addDistributedUri(localURI)) { + val src = getQualifiedLocalPath(localURI, hadoopConf) + val destPath = copyFileToRemote(dst, src, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource(destFs, hadoopConf, destPath, + localResources, LocalResourceType.FILE, destName, statCache) + } } else if (confKey != null) { // If the resource is intended for local use only, handle this downstream // by setting the appropriate property @@ -267,6 +285,13 @@ private[spark] class Client( } } + createConfArchive().foreach { file => + require(addDistributedUri(file.toURI())) + val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) + distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, + LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) + } + /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -284,13 +309,15 @@ private[spark] class Client( flist.split(',').foreach { file => val localURI = new URI(file.trim()) if (localURI.getScheme != LOCAL_SCHEME) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname + if (addDistributedUri(localURI)) { + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache) + if (addToClasspath) { + cachedSecondaryJarLinks += linkname + } } } else if (addToClasspath) { // Resource is intended for local use only and should be added to the class path @@ -306,6 +333,57 @@ private[spark] class Client( localResources } + /** + * Create an archive with the Hadoop config files for distribution. + * + * These are only used by the AM, since executors will use the configuration object broadcast by + * the driver. The files are zipped and added to the job as an archive, so that YARN will explode + * it when distributing to the AM. This directory is then added to the classpath of the AM + * process, just to make sure that everybody is using the same default config. + * + * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR + * shows up in the classpath before YARN_CONF_DIR. + * + * Currently this makes a shallow copy of the conf directory. If there are cases where a + * Hadoop config directory contains subdirectories, this code will have to be fixed. + */ + private def createConfArchive(): Option[File] = { + val hadoopConfFiles = new HashMap[String, File]() + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + sys.env.get(envKey).foreach { path => + val dir = new File(path) + if (dir.isDirectory()) { + dir.listFiles().foreach { file => + if (!hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } + } + } + } + } + + if (!hadoopConfFiles.isEmpty) { + val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + + val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) + try { + hadoopConfStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + hadoopConfStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, hadoopConfStream) + hadoopConfStream.closeEntry() + } + } finally { + hadoopConfStream.close() + } + + Some(hadoopConfArchive) + } else { + None + } + } + /** * Set up the environment for launching our ApplicationMaster container. */ @@ -313,7 +391,7 @@ private[spark] class Client( logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, true, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -400,7 +478,10 @@ private[spark] class Client( // Add Xmx for AM memory javaOpts += "-Xmx" + args.amMemory + "m" - val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + val tmpDir = new Path( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR + ) javaOpts += "-Djava.io.tmpdir=" + tmpDir // TODO: Remove once cpuset version is pushed out. @@ -415,6 +496,8 @@ private[spark] class Client( // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:MaxTenuringThreshold=31" + javaOpts += "-XX:SurvivorRatio=8" javaOpts += "-XX:+CMSIncrementalMode" javaOpts += "-XX:+CMSIncrementalPacing" javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" @@ -430,10 +513,11 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - sparkConf.getOption("spark.driver.extraJavaOptions") + val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) - .map(Utils.splitCommandString).getOrElse(Seq.empty) - .foreach(opts => javaOpts += opts) + driverOpts.foreach { opts => + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + } val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { @@ -455,7 +539,7 @@ private[spark] class Client( val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts) + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } } @@ -474,24 +558,50 @@ private[spark] class Client( } else { Nil } + val primaryPyFile = + if (args.primaryPyFile != null) { + Seq("--primary-py-file", args.primaryPyFile) + } else { + Nil + } + val pyFiles = + if (args.pyFiles != null) { + Seq("--py-files", args.pyFiles) + } else { + Nil + } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } + if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { + args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs + } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) // Command for the ApplicationMaster - val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ + val commands = prefixEnv ++ Seq( + YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server" + ) ++ javaOpts ++ amArgs ++ Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", @@ -540,36 +650,25 @@ private[spark] class Client( var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) - val report = getApplicationReport(appId) + val report: ApplicationReport = + try { + getApplicationReport(appId) + } catch { + case e: ApplicationNotFoundException => + logError(s"Application $appId not found.") + return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + } val state = report.getYarnApplicationState if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } @@ -590,24 +689,57 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. */ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } } @@ -660,6 +792,9 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" + // Subdirectory where the user's hadoop config files will be placed. + val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + /** * Find the user-defined Spark jar if configured, or return the jar containing this * class if not. @@ -684,7 +819,7 @@ object Client extends Logging { * Return the path to the given application's staging directory. */ private def getAppStagingDir(appId: ApplicationId): String = { - SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + buildPath(SPARK_STAGING, appId.toString()) } /** @@ -760,56 +895,64 @@ object Client extends Logging { /** * Populate the classpath entry in the given environment map. - * This includes the user jar, Spark jar, and any extra application jars. + * + * User jars are generally not added to the JVM's system classpath; those are handled by the AM + * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars + * are included in the system classpath, though. The extra class path and other uploaded files are + * always made available through the system class path. + * + * @param args Client arguments (when starting the AM) or null (when starting executors). */ private[yarn] def populateClasspath( args: ClientArguments, conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], + isAM: Boolean, extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry(Environment.PWD.$(), env) + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env + ) + + if (isAM) { + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_HADOOP_CONF_DIR, env) + } - // Normally the users app.jar is last in case conflicts with spark jars if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { - addUserClasspath(args, sparkConf, env) - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - } else { - addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - populateHadoopClasspath(conf, env) - addUserClasspath(args, sparkConf, env) + val userClassPath = + if (args != null) { + getUserClasspath(Option(args.userJar), Option(args.addJars)) + } else { + getUserClasspath(sparkConf) + } + userClassPath.foreach { x => + addFileToClasspath(x, null, env) + } } - - // Append all jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env) + addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + populateHadoopClasspath(conf, env) + sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) } /** - * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly - * to the classpath. + * Returns a list of URIs representing the user classpath. + * + * @param conf Spark configuration. */ - private def addUserClasspath( - args: ClientArguments, - conf: SparkConf, - env: HashMap[String, String]): Unit = { - - // If `args` is not null, we are launching an AM container. - // Otherwise, we are launching executor containers. - val (mainJar, secondaryJars) = - if (args != null) { - (args.userJar, args.addJars) - } else { - (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null)) - } + def getUserClasspath(conf: SparkConf): Array[URI] = { + getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR), + conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + } - addFileToClasspath(mainJar, APP_JAR, env) - if (secondaryJars != null) { - secondaryJars.split(",").filter(_.nonEmpty).foreach { jar => - addFileToClasspath(jar, null, env) - } - } + private def getUserClasspath( + mainJar: Option[String], + secondaryJars: Option[String]): Array[URI] = { + val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_)) + val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) + (mainUri ++ secondaryUris).toArray } /** @@ -820,25 +963,19 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * - * @param path Path to add to classpath (optional). + * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( - path: String, + uri: URI, fileName: String, env: HashMap[String, String]): Unit = { - if (path != null) { - scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { - val uri = new URI(path) - if (uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) - return - } - } - } - if (fileName != null) { - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env) + if (uri != null && uri.getScheme == LOCAL_SCHEME) { + addClasspathEntry(uri.getPath, env) + } else if (fileName != null) { + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) } } @@ -889,6 +1026,64 @@ object Client extends Logging { } } + /** + * Obtains token for the Hive metastore and adds them to the credentials. + */ + private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null) + + val hiveConf = hiveClass.getMethod("getConf").invoke(hive) + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + + val hiveConfGet = (param:String) => Option(hiveConfClass + .getMethod("get", classOf[java.lang.String]) + .invoke(hiveConf, param)) + + val metastore_uri = hiveConfGet("hive.metastore.uris") + + // Check for local metastore + if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val metastore_kerberos_principal_conf_var = mirror.classLoader + .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") + .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString + + val principal = hiveConfGet(metastore_kerberos_principal_conf_var) + + val username = Option(UserGroupInformation.getCurrentUser().getUserName) + if (principal != None && username != None) { + val tokenStr = hiveClass.getMethod("getDelegationToken", + classOf[java.lang.String], classOf[java.lang.String]) + .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] + + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + credentials.addToken(new Text("hive.server2.delegation.token"),hive2Token) + logDebug("Added hive.Server2.delegation.token to conf.") + hiveClass.getMethod("closeCurrent").invoke(null) + } else { + logError("Username or principal == NULL") + logError(s"""username=${username.getOrElse("(NULL)")}""") + logError(s"""principal=${principal.getOrElse("(NULL)")}""") + throw new IllegalArgumentException("username and/or principal is equal to null!") + } + } else { + logDebug("HiveMetaStore configured in localmode") + } + } catch { + case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e:Exception => { logError("Unexpected Exception " + e) + throw new RuntimeException("Unexpected exception", e) + } + } + } + } + /** * Return whether the two file systems are the same. */ @@ -934,4 +1129,23 @@ object Client extends Logging { new Path(qualifiedURI) } + /** + * Whether to consider jars provided by the user to have precedence over the Spark jars when + * loading user classes. + */ + def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = { + if (isDriver) { + conf.getBoolean("spark.driver.userClassPathFirst", false) + } else { + conf.getBoolean("spark.executor.userClassPathFirst", false) + } + } + + /** + * Joins all the path components using Path.SEPARATOR. + */ + def buildPath(components: String*): String = { + components.mkString(Path.SEPARATOR) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index f96b24551227..1423533470fc 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var userArgs: Seq[String] = Seq[String]() + var pyFiles: String = null + var primaryPyFile: String = null + var primaryRFile: String = null + var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS @@ -75,14 +78,23 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) .orNull - // If dynamic allocation is enabled, start at the max number of executors + // If dynamic allocation is enabled, start at the configured initial number of executors. + // Default to minExecutors if no initialExecutors is set. if (isDynamicAllocationEnabled) { + val minExecutorsConf = "spark.dynamicAllocation.minExecutors" + val initialExecutorsConf = "spark.dynamicAllocation.initialExecutors" val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors" - if (!sparkConf.contains(maxExecutorsConf)) { + val minNumExecutors = sparkConf.getInt(minExecutorsConf, 0) + val initialNumExecutors = sparkConf.getInt(initialExecutorsConf, minNumExecutors) + val maxNumExecutors = sparkConf.getInt(maxExecutorsConf, Integer.MAX_VALUE) + + // If defined, initial executors must be between min and max + if (initialNumExecutors < minNumExecutors || initialNumExecutors > maxNumExecutors) { throw new IllegalArgumentException( - s"$maxExecutorsConf must be set if dynamic allocation is enabled!") + s"$initialExecutorsConf must be between $minExecutorsConf and $maxNumExecutors!") } - numExecutors = sparkConf.get(maxExecutorsConf).toInt + + numExecutors = initialNumExecutors } } @@ -91,9 +103,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. */ private def validateArgs(): Unit = { - if (numExecutors <= 0) { + if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { throw new IllegalArgumentException( - "You must specify at least 1 executor!\n" + getUsageMessage()) + s""" + |Number of executors was $numExecutors, but must be at least 1 + |(or 0 if dynamic executor allocation is enabled). + |${getUsageMessage()} + """.stripMargin) } if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { throw new SparkException("Executor cores must not be less than " + @@ -123,7 +139,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) } private def parseArgs(inputArgs: List[String]): Unit = { - val userArgsBuffer = new ArrayBuffer[String]() var args = inputArgs while (!args.isEmpty) { @@ -136,11 +151,19 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) userClass = value args = tail + case ("--primary-py-file") :: value :: tail => + primaryPyFile = value + args = tail + + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") } - userArgsBuffer += value + userArgs += value args = tail case ("--master-class" | "--am-class") :: value :: tail => @@ -196,6 +219,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) addJars = value args = tail + case ("--py-files") :: value :: tail => + pyFiles = value + args = tail + case ("--files") :: value :: tail => files = value args = tail @@ -211,7 +238,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) } } - userArgs = userArgsBuffer.readOnly + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -223,6 +253,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster | mode) | --class CLASS_NAME Name of your application's main class (required) + | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) @@ -235,6 +267,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | 'default') | --addJars jars Comma separated list of local jars that want SparkContext.addJar | to work with. + | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to + | place on the PYTHONPATH for Python apps. | --files files Comma separated list of files to be distributed with the job. | --archives archives Comma separated list of archives to be distributed with the job. """.stripMargin diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c537da9f6755..9d04d241dae9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.yarn +import java.io.File import java.net.URI import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -56,17 +57,17 @@ class ExecutorRunnable( var rpc: YarnRPC = YarnRPC.create(conf) var nmClient: NMClient = _ val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - lazy val env = prepareEnvironment - - def run = { + lazy val env = prepareEnvironment(container) + + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) @@ -108,7 +109,13 @@ class ExecutorRunnable( } // Send the start request to the ContainerManager - nmClient.startContainer(container, ctx) + try { + nmClient.startContainer(container, ctx) + } catch { + case ex: Exception => + throw new SparkException(s"Exception while starting container ${container.getId}" + + s" on host $hostname", ex) + } } private def prepareCommand( @@ -128,21 +135,25 @@ class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + javaOpts += "-Xms" + executorMemoryString + javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts += opts + javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) } javaOpts += "-Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + new Path( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR + ) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend @@ -170,18 +181,29 @@ class ExecutorRunnable( // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline - javaOpts += " -XX:+UseConcMarkSweepGC " - javaOpts += " -XX:+CMSIncrementalMode " - javaOpts += " -XX:+CMSIncrementalPacing " - javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " - javaOpts += " -XX:CMSIncrementalDutyCycle=10 " + javaOpts += "-XX:+UseConcMarkSweepGC" + javaOpts += "-XX:+CMSIncrementalMode" + javaOpts += "-XX:+CMSIncrementalPacing" + javaOpts += "-XX:CMSIncrementalDutyCycleMin=0" + javaOpts += "-XX:CMSIncrementalDutyCycle=10" } */ // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) - val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", + val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => + val absPath = + if (new File(uri.getPath()).isAbsolute()) { + uri.getPath() + } else { + Client.buildPath(Environment.PWD.$(), uri.getPath()) + } + Seq("--user-class-path", "file:" + absPath) + }.toSeq + + val commands = prefixEnv ++ Seq( + YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in @@ -191,11 +213,13 @@ class ExecutorRunnable( "-XX:OnOutOfMemoryError='kill %p'") ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", - masterAddress.toString, - slaveId.toString, - hostname.toString, - executorCores.toString, - appId, + "--driver-url", masterAddress.toString, + "--executor-id", slaveId.toString, + "--hostname", hostname.toString, + "--cores", executorCores.toString, + "--app-id", appId) ++ + userClassPath ++ + Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") @@ -250,10 +274,10 @@ class ExecutorRunnable( localResources } - private def prepareEnvironment: HashMap[String, String] = { + private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp) + Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path @@ -266,6 +290,23 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + + // Add log urls + sys.env.get("SPARK_USER").foreach { user => + val containerId = ConverterUtils.toString(container.getId) + val address = container.getNodeHttpAddress + val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" + + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" + } + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } env } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d00f29665a58..b8f42dadcb46 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -32,9 +32,12 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -60,9 +63,13 @@ private[yarn] class YarnAllocator( import YarnAllocator._ + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + // Visible for testing. - val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]] + val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] // Containers that we no longer care about. We've either already told the RM to release them or @@ -76,10 +83,11 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var maxExecutors = args.numExecutors + @volatile private var targetNumExecutors = args.numExecutors // Keep track of which container is running which executor to remove the executors later - private val executorIdToContainer = new HashMap[String, Container] + // Visible for testing. + private[yarn] val executorIdToContainer = new HashMap[String, Container] // Executor memory in MB. protected val executorMemory = args.executorMemory @@ -99,10 +107,12 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = "akka.tcp://sparkDriver@%s:%s/user/%s".format( + private val driverUrl = AkkaUtils.address( + AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), + SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) @@ -123,10 +133,15 @@ private[yarn] class YarnAllocator( amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum /** - * Request as many executors from the ResourceManager as needed to reach the desired total. + * Request as many executors from the ResourceManager as needed to reach the desired total. If + * the requested total is smaller than the current number of running executors, no executors will + * be killed. */ def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { - maxExecutors = requestedTotal + if (requestedTotal != targetNumExecutors) { + logInfo(s"Driver requested a total number of $requestedTotal executor(s).") + targetNumExecutors = requestedTotal + } } /** @@ -137,8 +152,6 @@ private[yarn] class YarnAllocator( val container = executorIdToContainer.remove(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 - maxExecutors -= 1 - assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") } else { logWarning(s"Attempted to kill unknown executor $executorId!") } @@ -153,15 +166,8 @@ private[yarn] class YarnAllocator( * This must be synchronized because variables read in this method are mutated by other methods. */ def allocateResources(): Unit = synchronized { - val numPendingAllocate = getNumPendingAllocate - val missing = maxExecutors - numPendingAllocate - numExecutorsRunning + updateResourceRequests() - if (missing > 0) { - logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + - s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - } - - addResourceRequests(missing) val progressIndicator = 0.1f // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container // requests. @@ -191,15 +197,36 @@ private[yarn] class YarnAllocator( } /** - * Request numExecutors additional containers from YARN. Visible for testing. + * Update the set of container requests that we will sync with the RM based on the number of + * executors we have currently running and our target number of executors. + * + * Visible for testing. */ - def addResourceRequests(numExecutors: Int): Unit = { - for (i <- 0 until numExecutors) { - val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) - amClient.addContainerRequest(request) - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last - logInfo("Container request (host: %s, capability: %s".format(hostStr, resource)) + def updateResourceRequests(): Unit = { + val numPendingAllocate = getNumPendingAllocate + val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + + if (missing > 0) { + logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") + + for (i <- 0 until missing) { + val request = new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY) + amClient.addContainerRequest(request) + val nodes = request.getNodes + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + logInfo(s"Container request (host: $hostStr, capability: $resource)") + } + } else if (missing < 0) { + val numToCancel = math.min(numPendingAllocate, -missing) + logInfo(s"Canceling requests for $numToCancel executor containers") + + val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) + if (!matchingRequests.isEmpty) { + matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + } else { + logWarning("Expected to find pending requests, but found none.") + } } } @@ -256,7 +283,7 @@ private[yarn] class YarnAllocator( * containersToUse or remaining. * * @param allocatedContainer container that was given to us by YARN - * @location resource name, either a node, rack, or * + * @param location resource name, either a node, rack, or * * @param containersToUse list of containers that will be used * @param remaining list of containers that will not be used */ @@ -265,8 +292,14 @@ private[yarn] class YarnAllocator( location: String, containersToUse: ArrayBuffer[Container], remaining: ArrayBuffer[Container]): Unit = { + // SPARK-6050: certain Yarn configurations return a virtual core count that doesn't match the + // request; for example, capacity scheduler + DefaultResourceCalculator. So match on requested + // memory, but use the asked vcore count for matching, effectively disabling matching on vcore + // count. + val matchingResource = Resource.newInstance(allocatedContainer.getResource.getMemory, + resource.getVirtualCores) val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location, - allocatedContainer.getResource) + matchingResource) // Match the allocation to a request if (!matchingRequests.isEmpty) { @@ -284,7 +317,7 @@ private[yarn] class YarnAllocator( private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { for (container <- containersToUse) { numExecutorsRunning += 1 - assert(numExecutorsRunning <= maxExecutors) + assert(numExecutorsRunning <= targetNumExecutors) val executorHostname = container.getNodeId.getHost val containerId = container.getId executorIdCounter += 1 @@ -293,6 +326,7 @@ private[yarn] class YarnAllocator( assert(container.getResource.getMemory >= resource.getMemory) logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + executorIdToContainer(executorId) = container val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -319,7 +353,8 @@ private[yarn] class YarnAllocator( } } - private def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { + // Visible for testing. + private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4bff84612361..5881dc5ffa3a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,22 +17,21 @@ package org.apache.spark.deploy.yarn -import java.lang.{Boolean => JBoolean} import java.io.File -import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.HashMap +import scala.util.Try import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration import org.apache.spark.{SecurityManager, SparkConf} @@ -87,10 +86,10 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { object YarnSparkHadoopUtil { // Additional memory overhead - // 7% was arrived at experimentally. In the interest of minimizing memory waste while covering + // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. - val MEMORY_OVERHEAD_FACTOR = 0.07 + val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 val ANY_HOST = "*" @@ -106,7 +105,7 @@ object YarnSparkHadoopUtil { * If the map already contains this key, append the value to the existing value instead. */ def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = { - val newValue = if (env.contains(key)) { env(key) + File.pathSeparator + value } else value + val newValue = if (env.contains(key)) { env(key) + getClassPathSeparator + value } else value env.put(key, newValue) } @@ -186,4 +185,30 @@ object YarnSparkHadoopUtil { ) } + /** + * Expand environment variable using Yarn API. + * If environment.$$() is implemented, return the result of it. + * Otherwise, return the result of environment.$() + * Note: $$() is added in Hadoop 2.4. + */ + private lazy val expandMethod = + Try(classOf[Environment].getMethod("$$")) + .getOrElse(classOf[Environment].getMethod("$")) + + def expandEnvironment(environment: Environment): String = + expandMethod.invoke(environment).asInstanceOf[String] + + /** + * Get class path separator using Yarn API. + * If ApplicationConstants.CLASS_PATH_SEPARATOR is implemented, return it. + * Otherwise, return File.pathSeparator + * Note: CLASS_PATH_SEPARATOR is added in Hadoop 2.4. + */ + private lazy val classPathSeparatorField = + Try(classOf[ApplicationConstants].getField("CLASS_PATH_SEPARATOR")) + .getOrElse(classOf[File].getField("pathSeparator")) + + def getClassPathSeparator(): String = { + classPathSeparatorField.get(null).asInstanceOf[String] + } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 690f927e938c..99c05329b4d7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} @@ -33,7 +34,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - @volatile private var stopping: Boolean = false + private var monitorThread: Thread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -56,7 +57,8 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.submitApplication() waitForApplication() - asyncMonitorApplication() + monitorThread = asyncMonitorApplication() + monitorThread.start() } /** @@ -78,18 +80,12 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_MASTER_MEMORY" -> "SPARK_DRIVER_MEMORY or --driver-memory through spark-submit", "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") - // Do the same for deprecated properties: property -> suggestion - val deprecatedProps = Map("spark.master.memory" -> "--driver-memory through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - if (deprecatedProps.contains(sparkProp)) { - logWarning(s"NOTE: $sparkProp is deprecated. Use ${deprecatedProps(sparkProp)} instead.") - } } else if (System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { @@ -128,28 +124,22 @@ private[spark] class YarnClientSchedulerBackend( * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Unit = { + private def asyncMonitorApplication(): Thread = { assert(client != null && appId != null, "Application has not been submitted yet!") val t = new Thread { override def run() { - while (!stopping) { - val report = client.getApplicationReport(appId) - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.KILLED || - state == YarnApplicationState.FAILED) { - logError(s"Yarn application has already exited with state $state!") - sc.stop() - stopping = true - } - Thread.sleep(1000L) + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") } - Thread.currentThread().interrupt() } } t.setName("Yarn application state monitor") t.setDaemon(true) - t.start() + t } /** @@ -157,7 +147,7 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - stopping = true + monitorThread.interrupt() super.stop() client.stop() logInfo("Stopped") diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index be55d26f1cf6..72ec4d6b34af 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,33 +17,17 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.yarn.util.RackResolver - import org.apache.spark._ import org.apache.spark.deploy.yarn.ApplicationMaster -import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.Utils /** * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of * ApplicationMaster, etc is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnScheduler(sc) { logInfo("Created YarnClusterScheduler") - // Nothing else for now ... initialize application master : which needs a SparkContext to - // determine how to allocate. - // Note that only the first creation of a SparkContext influences (and ideally, there must be - // only one SparkContext, right ?). Subsequent creations are ignored since executors are already - // allocated by then. - - // By default, rack is unknown - override def getRackForHost(hostPort: String): Option[String] = { - val host = Utils.parseHostPort(hostPort)._1 - Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) - } - override def postStartHook() { ApplicationMaster.sparkContextInitialized(sc) super.postStartHook() diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala similarity index 77% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala index 2fa24cc43325..4ebf3af12b38 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala @@ -19,14 +19,18 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + import org.apache.spark._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils -/** - * This scheduler launches executors through Yarn - by calling into Client to launch the Spark AM. - */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 287c8e356350..6b8a5dbf6373 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file +log4j.rootCategory=DEBUG, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=target/unit-tests.log @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.apache.hadoop=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index aad50015b717..a51c2005cb47 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,6 +20,11 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.reflect.ClassTag +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -28,20 +33,20 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ - - -import org.scalatest.FunSuite -import org.scalatest.Matchers - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } -import scala.reflect.ClassTag -import scala.util.Try +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers { +class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + System.setProperty("SPARK_YARN_MODE", "true") + } + + override def afterAll(): Unit = { + System.clearProperty("SPARK_YARN_MODE") + } test("default Yarn application classpath") { Client.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) @@ -84,12 +89,13 @@ class ClientSuite extends FunSuite with Matchers { test("Local jar URIs") { val conf = new Configuration() val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK) + .set("spark.yarn.user.classpath.first", "true") val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - Client.populateClasspath(args, conf, sparkConf, env) + Client.populateClasspath(args, conf, sparkConf, env, true) - val cp = env("CLASSPATH").split(File.pathSeparator) + val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => val uri = new URI(entry) if (Client.LOCAL_SCHEME.equals(uri.getScheme())) { @@ -98,8 +104,16 @@ class ClientSuite extends FunSuite with Matchers { cp should not contain (uri.getPath()) } }) - cp should contain (Environment.PWD.$()) - cp should contain (s"${Environment.PWD.$()}${File.separator}*") + val pwdVar = + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + "{{PWD}}" + } else if (Utils.isWindows) { + "%PWD%" + } else { + Environment.PWD.$() + } + cp should contain(pwdVar) + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -111,7 +125,7 @@ class ClientSuite extends FunSuite with Matchers { val client = spy(new Client(args, conf, sparkConf)) doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), - any(classOf[Path]), anyShort(), anyBoolean()) + any(classOf[Path]), anyShort()) val tempDir = Utils.createTempDir() try { @@ -221,19 +235,26 @@ class ClientSuite extends FunSuite with Matchers { testCode(conf) } - def newEnv = MutableHashMap[String, String]() + def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]() - def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;") + def classpath(env: MutableHashMap[String, String]): Array[String] = + env(Environment.CLASSPATH.name).split(":|;|") - def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray + def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] = + (a ++ b).flatten.toArray - def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = - Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = { + Try(clazz.getField(field)) + .map(_.get(null).asInstanceOf[A]) + .toOption + .map(mapTo) + .getOrElse(defaults) + } def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 024b25f9d336..455f1019d86d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -79,7 +79,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { - override def equals(other: Any) = false + override def equals(other: Any): Boolean = false } def createAllocator(maxExecutors: Int = 5): YarnAllocator = { @@ -107,8 +107,8 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("single container allocated") { // request a single container and receive it - val handler = createAllocator() - handler.addResourceRequests(1) + val handler = createAllocator(1) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (1) @@ -118,13 +118,15 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0) + + val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size + size should be (0) } test("some containers allocated") { // request a few containers and receive some of them - val handler = createAllocator() - handler.addResourceRequests(4) + val handler = createAllocator(4) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) @@ -144,7 +146,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach test("receive more containers than requested") { val handler = createAllocator(2) - handler.addResourceRequests(2) + handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (2) @@ -162,6 +164,72 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.allocatedHostToContainersMap.contains("host4") should be (false) } + test("decrease total requested executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutors(3) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + + handler.requestTotalExecutors(2) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (1) + } + + test("decrease total requested executors to less than currently running") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + handler.requestTotalExecutors(3) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (3) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.getNumExecutorsRunning should be (2) + + handler.requestTotalExecutors(1) + handler.updateResourceRequests() + handler.getNumPendingAllocate should be (0) + handler.getNumExecutorsRunning should be (2) + } + + test("kill executors") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutors(1) + handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (1) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d79b85e867fc..3877da4120e7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,26 +17,34 @@ package org.apache.spark.deploy.yarn -import java.io.File +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ +import scala.collection.mutable -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.ByteStreams import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} - import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} import org.apache.spark.util.Utils +/** + * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN + * applications, and require the Spark assembly to be built before they can be successfully + * run. + */ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { - // log4j configuration for the Yarn containers, so that their output is collected - // by Yarn instead of trying to overwrite unit-tests.log. + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. private val LOG4J_CONF = """ |log4j.rootCategory=DEBUG, console |log4j.appender.console=org.apache.log4j.ConsoleAppender @@ -45,24 +53,42 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n """.stripMargin + private val TEST_PYFILE = """ + |import sys + |from operator import add + | + |from pyspark import SparkConf , SparkContext + |if __name__ == "__main__": + | if len(sys.argv) != 2: + | print >> sys.stderr, "Usage: test.py [result file]" + | exit(-1) + | sc = SparkContext(conf=SparkConf()) + | status = open(sys.argv[1],'w') + | result = "failure" + | rdd = sc.parallelize(range(10)) + | cnt = rdd.count() + | if cnt == 10: + | result = "success" + | status.write(result) + | status.close() + | sc.stop() + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ - private var oldConf: Map[String, String] = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ override def beforeAll() { - tempDir = Utils.createTempDir() + super.beforeAll() - val logConfDir = new File(tempDir, "log4j") + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") logConfDir.mkdir() val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8) - - val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator + - sys.props("java.class.path") - - oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap + Files.write(LOG4J_CONF, logConfFile, UTF_8) yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) yarnCluster.init(new YarnConfiguration()) @@ -93,57 +119,150 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - config.foreach { e => - sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) - } fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath())) - sys.props += ("spark.executor.instances" -> "1") - sys.props += ("spark.driver.extraClassPath" -> childClasspath) - sys.props += ("spark.executor.extraClassPath" -> childClasspath) - - super.beforeAll() + hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) } override def afterAll() { yarnCluster.stop() - sys.props.retain { case (k, v) => !k.startsWith("spark.") } - sys.props ++= oldConf super.afterAll() } test("run Spark in yarn-client mode") { - var result = File.createTempFile("result", null, tempDir) - YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) - checkResult(result) + testBasicYarnApp(true) } test("run Spark in yarn-cluster mode") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") + testBasicYarnApp(false) + } + + test("run Spark in yarn-cluster mode unsuccessfully") { + // Don't provide arguments so the driver will fail. + val exception = intercept[SparkException] { + runSpark(false, mainClassName(YarnClusterDriver.getClass)) + fail("Spark application should have failed.") + } + } + + // Enable this once fix SPARK-6700 + test("run Python application in yarn-cluster mode") { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + val pyFile = new File(tempDir, "test2.py") + Files.write(TEST_PYFILE, pyFile, UTF_8) var result = File.createTempFile("result", null, tempDir) - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--arg", result.getAbsolutePath(), - "--num-executors", "1") - Client.main(args) + // The sbt assembly does not include pyspark / py4j python dependencies, so we need to + // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala. + val sparkHome = sys.props("spark.test.home") + val extraConf = Map( + "spark.executorEnv.SPARK_HOME" -> sparkHome, + "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) + + runSpark(false, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), + appArgs = Seq(result.getAbsolutePath()), + extraConf = extraConf) checkResult(result) } - test("run Spark in yarn-cluster mode unsuccessfully") { - val main = YarnClusterDriver.getClass.getName().stripSuffix("$") + test("user class path first in client mode") { + testUseClassPathFirst(true) + } - // Use only one argument so the driver will fail - val args = Array("--class", main, - "--jar", "file:" + fakeSparkJar.getAbsolutePath(), - "--arg", "yarn-cluster", - "--num-executors", "1") - val exception = intercept[SparkException] { - Client.main(args) + test("user class path first in cluster mode") { + testUseClassPathFirst(false) + } + + private def testBasicYarnApp(clientMode: Boolean): Unit = { + var result = File.createTempFile("result", null, tempDir) + runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + + private def testUseClassPathFirst(clientMode: Boolean): Unit = { + // Create a jar file that contains a different version of "test.resource". + val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) + val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) + val driverResult = File.createTempFile("driver", null, tempDir) + val executorResult = File.createTempFile("executor", null, tempDir) + runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), + extraClassPath = Seq(originalJar.getPath()), + extraJars = Seq("local:" + userJar.getPath()), + extraConf = Map( + "spark.driver.userClassPathFirst" -> "true", + "spark.executor.userClassPathFirst" -> "true")) + checkResult(driverResult, "OVERRIDDEN") + checkResult(executorResult, "OVERRIDDEN") + } + + private def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } } - assert(Utils.exceptionString(exception).contains("Application finished with failed status")) + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) } /** @@ -152,37 +271,103 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit * the tests enforce that something is written to a file after everything is ok to indicate * that the job succeeded. */ - private def checkResult(result: File) = { - var resultString = Files.toString(result, Charsets.UTF_8) - resultString should be ("success") + private def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + private def checkResult(result: File, expected: String): Unit = { + var resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + private def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") } } +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo + } +} + private object YarnClusterDriver extends Logging with Matchers { - def main(args: Array[String]) = { - if (args.length != 2) { + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 1) { System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriver [master] [result file] + |Usage: YarnClusterDriver [result file] """.stripMargin) System.exit(1) } - val sc = new SparkContext(new SparkConf().setMaster(args(0)) + val sc = new SparkContext(new SparkConf() + .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) - val status = new File(args(1)) + val status = new File(args(0)) var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) data should be (Set(1, 2, 3, 4)) result = "success" } finally { sc.stop() - Files.write(result, status, Charsets.UTF_8) + Files.write(result, status, UTF_8) + } + + // verify log urls are present + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) + val executorInfos = listener.addedExecutorInfos.values + assert(executorInfos.nonEmpty) + executorInfos.foreach { info => + assert(info.logUrlMap.nonEmpty) + } + } + +} + +private object YarnClasspathTest { + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: YarnClasspathTest [driver result file] [executor result file] + """.stripMargin) + System.exit(1) + } + + readResource(args(0)) + val sc = new SparkContext(new SparkConf()) + try { + sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) } + } finally { + sc.stop() + } + } + + private def readResource(resultPath: String): Unit = { + var result = "failure" + try { + val ccl = Thread.currentThread().getContextClassLoader() + val resource = ccl.getResourceAsStream("test.resource") + val bytes = ByteStreams.toByteArray(resource) + result = new String(bytes, 0, bytes.length, UTF_8) + } finally { + Files.write(result, new File(resultPath), UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 2cc5abb3a890..9395316b71ff 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.{FunSuite, Matchers} import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.util.Utils class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { @@ -43,11 +46,11 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { logWarning("Cannot execute bash, skipping bash tests.") } - def bashTest(name: String)(fn: => Unit) = + def bashTest(name: String)(fn: => Unit): Unit = if (hasBash) test(name)(fn) else ignore(name)(fn) bashTest("shell script escaping") { - val scriptFile = File.createTempFile("script.", ".sh") + val scriptFile = File.createTempFile("script.", ".sh", Utils.createTempDir()) val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") try { val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") @@ -148,4 +151,26 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { } } + + test("test expandEnvironment result") { + val target = Environment.PWD + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("{{" + target + "}}") + } else if (Utils.isWindows) { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("%" + target + "%") + } else { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("$" + target) + } + + } + + test("test getClassPathSeparator result") { + if (classOf[ApplicationConstants].getFields().exists(_.getName == "CLASS_PATH_SEPARATOR")) { + YarnSparkHadoopUtil.getClassPathSeparator() should be ("") + } else if (Utils.isWindows) { + YarnSparkHadoopUtil.getClassPathSeparator() should be (";") + } else { + YarnSparkHadoopUtil.getClassPathSeparator() should be (":") + } + } }