diff --git a/.rat-excludes b/.rat-excludes index 994c7e86f8a91..8f2722cbd001f 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -80,8 +80,6 @@ local-1425081759269/* local-1426533911241/* local-1426633911242/* local-1430917381534/* -local-1430917381535_1 -local-1430917381535_2 DESCRIPTION NAMESPACE test_support/* diff --git a/R/create-docs.sh b/R/create-docs.sh index 6a4687b06ecb9..4194172a2e115 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -23,14 +23,14 @@ # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html -set -o pipefail -set -e - # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Install the package (this will also generate the Rd files) +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.sh b/R/install-dev.sh index 1edd551f8d243..55ed6f4be1a4a 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -26,20 +26,11 @@ # NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory # to load the SparkR package on the worker nodes. -set -o pipefail -set -e FWDIR="$(cd `dirname $0`; pwd)" LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -pushd $FWDIR - -# Generate Rd files if devtools is installed -Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' - -# Install SparkR to $LIB_DIR +# Install R R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ - -popd diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 22a4b5bf86ebd..88e1a508f37c4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -452,7 +452,7 @@ dropTempTable <- function(sqlContext, tableName) { #' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path @@ -462,21 +462,15 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - if (!is.null(schema)) { - stopifnot(class(schema) == "structType") - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, - schema$jobj, options) - } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) - } + sdf <- callJMethod(sqlContext, "load", source, options) dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { - read.df(sqlContext, path, source, schema, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { + read.df(sqlContext, path, source, ...) 
} #' Create an external table diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 773b6ecf582d9..ca94f1d4e7fd5 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,7 +24,7 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init() + sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) assign("sc", sc, envir=.GlobalEnv) sqlContext <- SparkR::sparkRSQL.init(sc) assign("sqlContext", sqlContext, envir=.GlobalEnv) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 30edfc8a7bd94..d2d82e791e876 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -504,19 +504,6 @@ test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) - - # Check if we can apply a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Run the same with loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) test_that("write.df() as parquet file", { diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc7..626c8577e31fe 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a96..132cd433d78a2 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/bin/pyspark b/bin/pyspark index f9dbddfa53560..7cb19c51b43a2 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,10 +17,24 @@ # limitations under the License. # +# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh -export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/pyspark [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 45e9e3def5121..09b4149c2a439 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,7 +21,6 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd -set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index 7bb1afe4b44f5..c49d97ce5cf25 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,12 +16,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # +set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . 
"$SPARK_HOME"/bin/load-spark-env.sh +if [ -z "$1" ]; then + echo "Usage: spark-class []" 1>&2 + exit 1 +fi + # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -92,4 +98,9 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") -exec "${CMD[@]}" + +if [ "${CMD[0]}" = "usage" ]; then + "${CMD[@]}" +else + exec "${CMD[@]}" +fi diff --git a/bin/spark-shell b/bin/spark-shell index a6dc863d83fc6..b3761b5e1375b 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,7 +29,20 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" + +usage() { + if [ -n "$1" ]; then + echo "$1" + fi + echo "Usage: ./bin/spark-shell [options]" + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit "$2" +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage "" 0 +fi # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 251309d67f860..00fd30fa38d36 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,7 +18,12 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. -set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] + +echo "%*" | findstr " \<--help\> \<-h\>" >nul +if %ERRORLEVEL% equ 0 ( + call :usage + exit /b 0 +) rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -32,4 +37,16 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +set SPARK_ERROR_LEVEL=%ERRORLEVEL% +if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( + call :usage + exit /b 1 +) +exit /b %SPARK_ERROR_LEVEL% + +:usage +echo %SPARK_LAUNCHER_USAGE_ERROR% +echo "Usage: .\bin\spark-shell.cmd [options]" >&2 +call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 +goto :eof diff --git a/bin/spark-sql b/bin/spark-sql index 4ea7bc6e39c07..ca1729f4cfcb4 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,6 +17,41 @@ # limitations under the License. # +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. 
+export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + +# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" -exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" + +function usage { + if [ -n "$1" ]; then + echo "$1" + fi + echo "Usage: ./bin/spark-sql [options] [cli option]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|--help" + pattern+="\|=======" + + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "CLI options:" + "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 + exit "$2" +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage "" 0 +fi + +exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 255378b0f077c..0e0afe71a0f05 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,4 +22,16 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 +# Only define a usage function if an upstream script hasn't done so. +if ! type -t usage >/dev/null 2>&1; then + usage() { + if [ -n "$1" ]; then + echo "$1" + fi + "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help + exit "$2" + } + export -f usage +fi + exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 651376e526928..d3fc4a5cc3f6e 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,4 +24,15 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -%~dp0spark-class2.cmd %CLASS% %* +call %~dp0spark-class2.cmd %CLASS% %* +set SPARK_ERROR_LEVEL=%ERRORLEVEL% +if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( + call :usage + exit /b 1 +) +exit /b %SPARK_ERROR_LEVEL% + +:usage +echo %SPARK_LAUNCHER_USAGE_ERROR% +call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help +goto :eof diff --git a/bin/sparkR b/bin/sparkR index 464c29f369424..8c918e2b09aef 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,7 +17,23 @@ # limitations under the License. # +# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$SPARK_HOME"/bin/load-spark-env.sh -export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f5..7de0011a48ca8 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wildcard "*" can be used as an instance name, in +# and "applications". 
A wild card "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wildcard properties can be overridden by more specific properties. +# 3. Wild card properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,13 +47,6 @@ # instance master and applications. MetricsServlet may not be configured by self. # -## List of available common sources and their properties. - -# org.apache.spark.metrics.source.JvmSource -# Note: Currently, JvmSource is the only available common source -# to add additionaly to an instance, to enable this, -# set the "class" option to its fully qulified class name (see examples below) - ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink diff --git a/core/pom.xml b/core/pom.xml index 40a64beccdc24..a02184222e9f0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml @@ -481,6 +481,29 @@ + + sparkr-docs + + + + org.codehaus.mojo + exec-maven-plugin + + + sparkr-pkg-docs + compile + + exec + + + + + ..${path.separator}R${path.separator}create-docs${script.extension} + + + + + diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 9939103bb0903..9514604752640 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -150,13 +150,6 @@ private[spark] class ExecutorAllocationManager( // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. val executorAllocationManagerSource = new ExecutorAllocationManagerSource - // Whether we are still waiting for the initial set of executors to be allocated. - // While this is true, we will not cancel outstanding executor requests. This is - // set to false when: - // (1) a stage is submitted, or - // (2) an executor idle timeout has elapsed. - @volatile private var initializing: Boolean = true - /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. 
@@ -247,7 +240,6 @@ private[spark] class ExecutorAllocationManager( removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { - initializing = false removeExecutor(executorId) } !expired @@ -269,23 +261,15 @@ private[spark] class ExecutorAllocationManager( private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (initializing) { - // Do not change our target while we are still initializing, - // Otherwise the first job may have to ramp up unnecessarily - 0 - } else if (maxNeeded < numExecutorsTarget) { + if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget numExecutorsTarget = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 - - // If the new target has not changed, avoid sending a message to the cluster manager - if (numExecutorsTarget < oldNumExecutorsTarget) { - client.requestTotalExecutors(numExecutorsTarget) - logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + - s"$oldNumExecutorsTarget) because not all requested executors are actually needed") - } + logInfo(s"Lowering target number of executors to $numExecutorsTarget because " + + s"not all requests are actually needed (previously $oldNumExecutorsTarget)") numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -493,7 +477,6 @@ private[spark] class ExecutorAllocationManager( private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { - initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a0eae774268ed..8cf4d58847d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -82,13 +82,13 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + private[spark] var exitFn: () => Unit = () => System.exit(1) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn(1) + exitFn() } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,7 +99,7 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn(0) + exitFn() } def main(args: Array[String]): Unit = { @@ -160,7 +160,7 @@ object SparkSubmit { // detect exceptions with empty stack traces here, and treat them differently. 
if (e.getStackTrace().length == 0) { printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn(1) + exitFn() } else { throw e } @@ -425,6 +425,7 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -445,7 +446,7 @@ object SparkSubmit { OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), @@ -699,7 +700,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. */ - private[deploy] def isSqlShell(mainClass: String): Boolean = { + private def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b7429a901e162..cc6a7bd9f4119 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,15 +17,12 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} -import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -415,9 +412,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() - case USAGE_ERROR => - printUsageAndExit(1) - case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -455,14 +449,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( + outStream.println( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) - outStream.println(command) - - outStream.println( - """ + |Usage: spark-submit --status [submission ID] --master [spark://...] + | |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -534,65 +525,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. 
""".stripMargin ) - - if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) - } - - SparkSubmit.exitFn(exitCode) + SparkSubmit.exitFn() } - - /** - * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter - * the results to remove unwanted lines. - * - * Since the CLI will call `System.exit()`, we install a security manager to prevent that call - * from working, and restore the original one afterwards. - */ - private def getSqlShellOptions(): String = { - val currentOut = System.out - val currentErr = System.err - val currentSm = System.getSecurityManager() - try { - val out = new ByteArrayOutputStream() - val stream = new PrintStream(out) - System.setOut(stream) - System.setErr(stream) - - val sm = new SecurityManager() { - override def checkExit(status: Int): Unit = { - throw new SecurityException() - } - - override def checkPermission(perm: java.security.Permission): Unit = {} - } - System.setSecurityManager(sm) - - try { - Class.forName(mainClass).getMethod("main", classOf[Array[String]]) - .invoke(null, Array(HELP)) - } catch { - case e: InvocationTargetException => - // Ignore SecurityException, since we throw it above. - if (!e.getCause().isInstanceOf[SecurityException]) { - throw e - } - } - - stream.flush() - - // Get the output and discard any unnecessary lines from it. - Source.fromString(new String(out.toByteArray())).getLines - .filter { line => - !line.startsWith("log4j") && !line.startsWith("usage") - } - .mkString("\n") - } finally { - System.setSecurityManager(currentSm) - System.setOut(currentOut) - System.setErr(currentErr) - } - } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 5f5e0fe1c34d7..298a8201960d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,9 +17,6 @@ package org.apache.spark.deploy.history -import java.util.zip.ZipOutputStream - -import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI private[spark] case class ApplicationAttemptInfo( @@ -65,12 +62,4 @@ private[history] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() - /** - * Writes out the event logs to the output stream provided. The logs will be compressed into a - * single zip file and written out. - * @throws SparkException if the logs for the app id cannot be found. 
- */ - @throws(classOf[SparkException]) - def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit - } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 5427a88f32ffd..45c2be34c8680 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,18 +17,16 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream} import java.util.concurrent.{ExecutorService, Executors, TimeUnit} -import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable -import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ @@ -61,8 +59,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) + private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -222,58 +219,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def writeEventLogs( - appId: String, - attemptId: Option[String], - zipStream: ZipOutputStream): Unit = { - - /** - * This method compresses the files passed in, and writes the compressed data out into the - * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being - * the name of the file being compressed. - */ - def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { - val fs = FileSystem.get(hadoopConf) - val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer - try { - outputStream.putNextEntry(new ZipEntry(entryName)) - ByteStreams.copy(inputStream, outputStream) - outputStream.closeEntry() - } finally { - inputStream.close() - } - } - - applications.get(appId) match { - case Some(appInfo) => - try { - // If no attempt is specified, or there is no attemptId for attempts, return all attempts - appInfo.attempts.filter { attempt => - attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get - }.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - // If this is a legacy directory, then add the directory to the zipStream and add - // each file to that directory. 
- if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { - val files = fs.listStatus(logPath) - zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) - zipStream.closeEntry() - files.foreach { file => - val path = file.getPath - zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) - } - } else { - zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) - } - } - } finally { - zipStream.close() - } - case None => throw new SparkException(s"Logs for $appId not found.") - } - } - - /** * Replay the log files in the list and merge the list of old applications with new ones */ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 10638afb74900..5a0eb585a9049 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException -import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -174,13 +173,6 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } - override def writeEventLogs( - appId: String, - attemptId: Option[String], - zipStream: ZipOutputStream): Unit = { - provider.writeEventLogs(appId, attemptId, zipStream) - } - /** * Returns the provider configuration to show in the listing page. * diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index e8ef60bd5428a..be8560d10fc62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

Driver state information for driver id {driverId}

- Back to Drivers + Back to Drivers

Driver state: {driverState.state}

diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 5a1d06eb87db9..dc2bee6f2bdca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy.worker.ui -import java.io.File -import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -137,13 +135,6 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) } - // Verify that the normalized path of the log directory is in the working directory - val normalizedUri = new URI(logDirectory).normalize() - val normalizedLogDir = new File(normalizedUri.getPath) - if (!Utils.isInDirectory(workDir, normalizedLogDir)) { - return ("Error: invalid log directory " + logDirectory, 0, 0, 0) - } - try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") @@ -159,7 +150,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val endIndex = math.min(startIndex + byteLength, totalLength) + val endIndex = math.min(startIndex + totalLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2e..2ab41ba488ff6 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.4.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5d..673cd0e19eba2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails this number of times, the entire + * @param maxTaskFailures if any particular task fails more than this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7c7f70d8a193b..fcad959540f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -103,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. 
logWarning(s"Ignored task status update ($taskId state $state) " + - s"from unknown executor with ID $executorId") + "from unknown executor $sender with ID $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 50b6ba67e9931..f73c742732dec 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.status.api.v1 -import java.util.zip.ZipOutputStream import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -165,18 +164,6 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { } } - @Path("applications/{appId}/logs") - def getEventLogs( - @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) - } - - @Path("applications/{appId}/{attemptId}/logs") - def getEventLogs( - @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) - } } private[spark] object ApiRootResource { @@ -206,17 +193,6 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] - /** - * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is - * [[None]], event logs for all attempts of this application will be written out. - */ - def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { - Response.serverError() - .entity("Event logs are only available through the history server.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() - } - /** * Get the spark UI with the given appID, and apply a function * to it. If there is no such app, throw an appropriate exception diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala deleted file mode 100644 index 22e21f0c62a29..0000000000000 --- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.status.api.v1 - -import java.io.OutputStream -import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Produces} -import javax.ws.rs.core.{MediaType, Response, StreamingOutput} - -import scala.util.control.NonFatal - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil - -@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) -private[v1] class EventLogDownloadResource( - val uIRoot: UIRoot, - val appId: String, - val attemptId: Option[String]) extends Logging { - val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) - - @GET - def getEventLogs(): Response = { - try { - val fileName = { - attemptId match { - case Some(id) => s"eventLogs-$appId-$id.zip" - case None => s"eventLogs-$appId.zip" - } - } - - val stream = new StreamingOutput { - override def write(output: OutputStream): Unit = { - val zipStream = new ZipOutputStream(output) - try { - uIRoot.writeEventLogs(appId, attemptId, zipStream) - } finally { - zipStream.close() - } - - } - } - - Response.ok(stream) - .header("Content-Disposition", s"attachment; filename=$fileName") - .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) - .build() - } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 1d31fce4c697b..f39e961772c46 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,12 +17,8 @@ package org.apache.spark.ui.jobs -import java.util.concurrent.TimeoutException - import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import com.google.common.annotations.VisibleForTesting - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -530,30 +526,4 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onApplicationStart(appStarted: SparkListenerApplicationStart) { startTime = appStarted.time } - - /** - * For testing only. Wait until at least `numExecutors` executors are up, or throw - * `TimeoutException` if the waiting time elapsed before `numExecutors` executors up. - * - * @param numExecutors the number of executors to wait at least - * @param timeout time to wait in milliseconds - */ - @VisibleForTesting - private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = { - val finishTime = System.currentTimeMillis() + timeout - while (System.currentTimeMillis() < finishTime) { - val numBlockManagers = synchronized { - blockManagerIds.size - } - if (numBlockManagers >= numExecutors + 1) { - // Need to count the block manager in driver - return - } - // Sleep rather than using wait/notify, because this is used only for testing and wait/notify - // add overhead in the general case. 
- Thread.sleep(10) - } - throw new TimeoutException( - s"Can't find $numExecutors executors before $timeout milliseconds elapsed") - } } diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 61b5a4cecddce..1861d38640102 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -120,22 +120,21 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue - * emptied. + * time has elapsed. Return true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. */ @VisibleForTesting - @throws(classOf[TimeoutException]) - def waitUntilEmpty(timeoutMillis: Long): Unit = { + def waitUntilEmpty(timeoutMillis: Int): Boolean = { val finishTime = System.currentTimeMillis + timeoutMillis while (!queueIsEmpty) { if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") + return false } /* Sleep rather than using wait/notify, because this is used only for testing and * wait/notify add overhead in the general case. */ Thread.sleep(10) } + true } /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5f132410540fd..693e1a0a3d5f0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2227,22 +2227,6 @@ private[spark] object Utils extends Logging { } } - /** - * Return whether the specified file is a parent directory of the child file. - */ - def isInDirectory(parent: File, child: File): Boolean = { - if (child == null || parent == null) { - return false - } - if (!child.exists() || !parent.exists() || !parent.isDirectory()) { - return false - } - if (parent.equals(child)) { - return true - } - isInDirectory(parent, child.getParentFile) - } - } private [util] class SparkShutdownHookManager { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 64e7102e3654c..1501111a06655 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -20,8 +20,6 @@ package org.apache.spark.util.collection import scala.reflect._ import com.google.common.hash.Hashing -import org.apache.spark.annotation.Private - /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never * removed. @@ -39,7 +37,7 @@ import org.apache.spark.annotation.Private * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). 
*/ -@Private +private[spark] class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) @@ -112,14 +110,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( rehashIfNeeded(k, grow, move) } - def union(other: OpenHashSet[T]): OpenHashSet[T] = { - val iterator = other.iterator - while (iterator.hasNext) { - add(iterator.next()) - } - this - } - /** * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. * The caller is responsible for calling rehashIfNeeded. diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index d575bf2f284b9..ce4fe80b66aa5 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -7,22 +7,6 @@ "sparkUser" : "irashid", "completed" : true } ] -}, { - "id" : "local-1430917381535", - "name" : "Spark shell", - "attempts" : [ { - "attemptId" : "2", - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:00.950GMT", - "sparkUser" : "irashid", - "completed" : true - }, { - "attemptId" : "1", - "startTime" : "2015-05-06T13:03:00.880GMT", - "endTime" : "2015-05-06T13:03:00.890GMT", - "sparkUser" : "irashid", - "completed" : true - } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index d575bf2f284b9..ce4fe80b66aa5 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -7,22 +7,6 @@ "sparkUser" : "irashid", "completed" : true } ] -}, { - "id" : "local-1430917381535", - "name" : "Spark shell", - "attempts" : [ { - "attemptId" : "2", - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:00.950GMT", - "sparkUser" : "irashid", - "completed" : true - }, { - "attemptId" : "1", - "startTime" : "2015-05-06T13:03:00.880GMT", - "endTime" : "2015-05-06T13:03:00.890GMT", - "sparkUser" : "irashid", - "completed" : true - } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 15c2de8ef99ea..dca86fe5f7e6a 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -7,22 +7,6 @@ "sparkUser" : "irashid", "completed" : true } ] -}, { - "id" : "local-1430917381535", - "name" : "Spark shell", - "attempts" : [ { - "attemptId" : "2", - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:00.950GMT", - "sparkUser" : "irashid", - "completed" : true - }, { - "attemptId" : "1", - "startTime" : "2015-05-06T13:03:00.880GMT", - "endTime" : "2015-05-06T13:03:00.890GMT", - "sparkUser" : "irashid", - "completed" : true - } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", @@ -40,14 +24,12 @@ "completed" : true } ] }, { - "id": "local-1425081759269", - "name": "Spark shell", - 
"attempts": [ - { - "startTime": "2015-02-28T00:02:38.277GMT", - "endTime": "2015-02-28T00:02:46.912GMT", - "sparkUser": "irashid", - "completed": true - } - ] + "id" : "local-1425081759269", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-02-28T00:02:38.277GMT", + "endTime" : "2015-02-28T00:02:46.912GMT", + "sparkUser" : "irashid", + "completed" : true + } ] } ] \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_1 b/core/src/test/resources/spark-events/local-1430917381535_1 deleted file mode 100644 index d5a1303344825..0000000000000 --- a/core/src/test/resources/spark-events/local-1430917381535_1 +++ /dev/null @@ -1,5 +0,0 @@ -{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380880} -{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API 
Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} -{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380880,"User":"irashid","App Attempt ID":"1"} -{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380890} \ No newline at end of file diff --git a/core/src/test/resources/spark-events/local-1430917381535_2 b/core/src/test/resources/spark-events/local-1430917381535_2 deleted file mode 100644 index abb637a22e1e3..0000000000000 --- a/core/src/test/resources/spark-events/local-1430917381535_2 +++ /dev/null @@ -1,5 +0,0 @@ -{"Event":"SparkListenerLogStart","Spark Version":"1.4.0-SNAPSHOT"} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"localhost","Port":61103},"Maximum Memory":278019440,"Timestamp":1430917380893} -{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","Java Version":"1.8.0_25 (Oracle Corporation)","Scala Version":"version 2.10.4"},"Spark 
Properties":{"spark.driver.host":"192.168.1.102","spark.eventLog.enabled":"true","spark.driver.port":"61101","spark.repl.class.uri":"http://192.168.1.102:61100","spark.jars":"","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.executor.id":"driver","spark.master":"local[*]","spark.eventLog.dir":"/Users/irashid/github/kraps/core/src/test/resources/spark-events","spark.fileserver.uri":"http://192.168.1.102:61102","spark.tachyonStore.folderName":"spark-aaaf41b3-d1dd-447f-8951-acf51490758b","spark.app.id":"local-1430917381534"},"System Properties":{"java.io.tmpdir":"/var/folders/36/m29jw1z95qv4ywb1c4n0rz000000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/irashid","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib","user.dir":"/Users/irashid/github/spark","java.library.path":"/Users/irashid/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.25-b02","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_25-b17","java.vm.info":"mixed mode","java.ext.dirs":"/Users/irashid/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.9.5","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle 
Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"irashid","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --conf spark.eventLog.enabled=true --conf spark.eventLog.dir=/Users/irashid/github/kraps/core/src/test/resources/spark-events --class org.apache.spark.repl.Main spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home/jre","java.version":"1.8.0_25","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/etc/hadoop":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/Users/irashid/github/spark/conf/":"System Classpath","/Users/irashid/github/spark/assembly/target/scala-2.10/spark-assembly-1.4.0-SNAPSHOT-hadoop2.5.0.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-core-3.2.10.jar":"System Classpath","/Users/irashid/github/spark/lib_managed/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath"}} -{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"local-1430917381535","Timestamp":1430917380893,"User":"irashid","App Attempt ID":"2"} -{"Event":"SparkListenerApplicationEnd","Timestamp":1430917380950} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 803e1831bb269..1c2b681f0b843 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -90,7 +90,7 @@ class ExecutorAllocationManagerSuite } test("add executors") { - sc = createSparkContext(1, 10, 1) + sc = createSparkContext(1, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -135,7 +135,7 @@ class ExecutorAllocationManagerSuite } test("add executors capped by num pending tasks") { - sc = createSparkContext(0, 10, 0) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) @@ -186,7 +186,7 @@ class ExecutorAllocationManagerSuite } test("cancel pending executors when no longer needed") { - sc = createSparkContext(0, 10, 0) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) @@ -213,7 +213,7 @@ class ExecutorAllocationManagerSuite } test("remove executors") { - sc = createSparkContext(5, 10, 5) + sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } @@ -263,7 +263,7 @@ class ExecutorAllocationManagerSuite } test ("interleaving add and remove") { - sc = createSparkContext(5, 10, 5) + sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) @@ -331,7 +331,7 @@ class ExecutorAllocationManagerSuite } test("starting/canceling add timer") { - sc = createSparkContext(2, 10, 2) + sc = createSparkContext(2, 10) val clock = new ManualClock(8888L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -363,7 +363,7 @@ class ExecutorAllocationManagerSuite } test("starting/canceling remove timers") { - sc = createSparkContext(2, 10, 2) + sc = 
createSparkContext(2, 10) val clock = new ManualClock(14444L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -410,7 +410,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop with no events") { - sc = createSparkContext(0, 20, 0) + sc = createSparkContext(0, 20) val manager = sc.executorAllocationManager.get val clock = new ManualClock(2020L) manager.setClock(clock) @@ -436,7 +436,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop add behavior") { - sc = createSparkContext(0, 20, 0) + sc = createSparkContext(0, 20) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -486,7 +486,7 @@ class ExecutorAllocationManagerSuite } test("mock polling loop remove behavior") { - sc = createSparkContext(1, 20, 1) + sc = createSparkContext(1, 20) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -547,7 +547,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger add executors correctly") { - sc = createSparkContext(2, 10, 2) + sc = createSparkContext(2, 10) val manager = sc.executorAllocationManager.get assert(addTime(manager) === NOT_SET) @@ -577,7 +577,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger remove executors correctly") { - sc = createSparkContext(2, 10, 2) + sc = createSparkContext(2, 10) val manager = sc.executorAllocationManager.get assert(removeTimes(manager).isEmpty) @@ -608,7 +608,7 @@ class ExecutorAllocationManagerSuite } test("listeners trigger add and remove executor callbacks correctly") { - sc = createSparkContext(2, 10, 2) + sc = createSparkContext(2, 10) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -641,7 +641,7 @@ class ExecutorAllocationManagerSuite } test("SPARK-4951: call onTaskStart before onBlockManagerAdded") { - sc = createSparkContext(2, 10, 2) + sc = createSparkContext(2, 10) val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) @@ -677,7 +677,7 @@ class ExecutorAllocationManagerSuite } test("avoid ramp up when target < running executors") { - sc = createSparkContext(0, 100000, 0) + sc = createSparkContext(0, 100000) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) @@ -701,67 +701,13 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 16) } - test("avoid ramp down initial executors until first job is submitted") { - sc = createSparkContext(2, 5, 3) - val manager = sc.executorAllocationManager.get - val clock = new ManualClock(10000L) - manager.setClock(clock) - - // Verify the initial number of executors - assert(numExecutorsTarget(manager) === 3) - schedule(manager) - // Verify whether the initial number of executors is kept with no pending tasks - assert(numExecutorsTarget(manager) === 3) - - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) - clock.advance(100L) - - assert(maxNumExecutorsNeeded(manager) === 2) - schedule(manager) - - // Verify that current number of executors should be ramp down when first job is submitted - assert(numExecutorsTarget(manager) === 2) - } - - test("avoid ramp down initial executors until idle executor is timeout") { - sc = createSparkContext(2, 5, 3) - val manager = sc.executorAllocationManager.get - val clock = new ManualClock(10000L) - 
manager.setClock(clock) - - // Verify the initial number of executors - assert(numExecutorsTarget(manager) === 3) - schedule(manager) - // Verify the initial number of executors is kept when no pending tasks - assert(numExecutorsTarget(manager) === 3) - (0 until 3).foreach { i => - onExecutorAdded(manager, s"executor-$i") - } - - clock.advance(executorIdleTimeout * 1000) - - assert(maxNumExecutorsNeeded(manager) === 0) - schedule(manager) - // Verify executor is timeout but numExecutorsTarget is not recalculated - assert(numExecutorsTarget(manager) === 3) - - // Schedule again to recalculate the numExecutorsTarget after executor is timeout - schedule(manager) - // Verify that current number of executors should be ramp down when executor is timeout - assert(numExecutorsTarget(manager) === 2) - } - - private def createSparkContext( - minExecutors: Int = 1, - maxExecutors: Int = 5, - initialExecutors: Int = 1): SparkContext = { + private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test-executor-allocation-manager") .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) - .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) .set("spark.dynamicAllocation.schedulerBacklogTimeout", s"${schedulerBacklogTimeout.toString}s") .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", @@ -845,10 +791,6 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _schedule() } - private def maxNumExecutorsNeeded(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _maxNumExecutorsNeeded() - } - private def addExecutors(manager: ExecutorAllocationManager): Int = { val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() manager invokePrivate _addExecutors(maxNumExecutorsNeeded) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 140012226fdbb..bac6fdbcdc976 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -55,14 +55,6 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) - // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. - // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then - // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch - // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. - // In this case, we won't receive FetchFailed. And it will make this test fail. 
- // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) - val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) rdd.count() diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c054c718075f8..c05e8bb6538ba 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.broadcast +import scala.concurrent.duration._ import scala.util.Random import org.scalatest.Assertions +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec @@ -310,7 +312,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + eventually(timeout(10.seconds), interval(10.milliseconds)) { + _sc.jobProgressListener.synchronized { + val numBlockManagers = _sc.jobProgressListener.blockManagerIds.size + assert(numBlockManagers == numSlaves + 1, + s"Expect ${numSlaves + 1} block managers, but was ${numBlockManagers}") + } + } _sc } else { new SparkContext("local", "test", broadcastConf) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index ddc92814c0acf..c215b0582889f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -41,7 +41,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) // Browse to each URL to check that it's valid @@ -71,7 +71,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] assert(listeners.size === 1) val listener = listeners(0) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46ea28d0f18f6..46369457f000a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -62,7 +62,7 @@ class SparkSubmitSuite SparkSubmit.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = (_) => exitedCleanly = true + SparkSubmit.exitFn = () => exitedCleanly = true val thread = new Thread { override def run() = try { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 09075eeb539aa..0f6933df9e6bc 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.deploy.history -import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, - FileOutputStream, OutputStreamWriter} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStreamWriter} import java.net.URI import java.util.concurrent.TimeUnit -import java.util.zip.{ZipInputStream, ZipOutputStream} import scala.io.Source -import com.google.common.base.Charsets -import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter @@ -339,40 +335,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!log2.exists()) } - test("Event log copy") { - val provider = new FsHistoryProvider(createTestConf()) - val logs = (1 to 2).map { i => - val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) - writeFile(log, true, None, - SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) - ) - log - } - provider.checkForLogs() - - (1 to 2).foreach { i => - val underlyingStream = new ByteArrayOutputStream() - val outputStream = new ZipOutputStream(underlyingStream) - provider.writeEventLogs("downloadApp1", Some(s"attempt$i"), outputStream) - outputStream.close() - val inputStream = new ZipInputStream(new ByteArrayInputStream(underlyingStream.toByteArray)) - var totalEntries = 0 - var entry = inputStream.getNextEntry - entry should not be null - while (entry != null) { - val actual = new String(ByteStreams.toByteArray(inputStream), Charsets.UTF_8) - val expected = Files.toString(logs.find(_.getName == entry.getName).get, Charsets.UTF_8) - actual should be (expected) - totalEntries += 1 - entry = inputStream.getNextEntry - } - totalEntries should be (1) - inputStream.close() - } - } - /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. 
Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e5b5e1bb65337..14f2d1a5894b8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -16,13 +16,10 @@ */ package org.apache.spark.deploy.history -import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} +import java.io.{File, FileInputStream, FileWriter, IOException} import java.net.{HttpURLConnection, URL} -import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} -import com.google.common.base.Charsets -import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.mockito.Mockito.when import org.scalatest.{BeforeAndAfter, Matchers} @@ -150,70 +147,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - test("download all logs for app with multiple attempts") { - doDownloadTest("local-1430917381535", None) - } - - test("download one log for app with multiple attempts") { - (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } - } - - test("download legacy logs - all attempts") { - doDownloadTest("local-1426533911241", None, legacy = true) - } - - test("download legacy logs - single attempts") { - (1 to 2). foreach { - attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) - } - } - - // Test that the files are downloaded correctly, and validate them. - def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { - - val url = attemptId match { - case Some(id) => - new URL(s"${generateURL(s"applications/$appId")}/$id/logs") - case None => - new URL(s"${generateURL(s"applications/$appId")}/logs") - } - - val (code, inputStream, error) = HistoryServerSuite.connectAndGetInputStream(url) - code should be (HttpServletResponse.SC_OK) - inputStream should not be None - error should be (None) - - val zipStream = new ZipInputStream(inputStream.get) - var entry = zipStream.getNextEntry - entry should not be null - val totalFiles = { - if (legacy) { - attemptId.map { x => 3 }.getOrElse(6) - } else { - attemptId.map { x => 1 }.getOrElse(2) - } - } - var filesCompared = 0 - while (entry != null) { - if (!entry.isDirectory) { - val expectedFile = { - if (legacy) { - val splits = entry.getName.split("/") - new File(new File(logDir, splits(0)), splits(1)) - } else { - new File(logDir, entry.getName) - } - } - val expected = Files.toString(expectedFile, Charsets.UTF_8) - val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) - actual should be (expected) - filesCompared += 1 - } - entry = zipStream.getNextEntry - } - filesCompared should be (totalFiles) - } - test("response codes on bad paths") { val badAppId = getContentAndCode("applications/foobar") badAppId._1 should be (HttpServletResponse.SC_NOT_FOUND) @@ -269,11 +202,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } def getUrl(path: String): String = { - HistoryServerSuite.getUrl(generateURL(path)) - } - - def generateURL(path: String): URL = { - new URL(s"http://localhost:$port/api/v1/$path") + HistoryServerSuite.getUrl(new URL(s"http://localhost:$port/api/v1/$path")) } def generateExpectation(name: String, path: String): Unit = { @@ -304,18 +233,13 @@ 
object HistoryServerSuite { } def getContentAndCode(url: URL): (Int, Option[String], Option[String]) = { - val (code, in, errString) = connectAndGetInputStream(url) - val inString = in.map(IOUtils.toString) - (code, inString, errString) - } - - def connectAndGetInputStream(url: URL): (Int, Option[InputStream], Option[String]) = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod("GET") connection.connect() val code = connection.getResponseCode() - val inStream = try { - Option(connection.getInputStream()) + val inString = try { + val in = Option(connection.getInputStream()) + in.map(IOUtils.toString) } catch { case io: IOException => None } @@ -325,7 +249,7 @@ object HistoryServerSuite { } catch { case io: IOException => None } - (code, inStream, errString) + (code, inString, errString) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala index 72eaffb416981..572360ddb95d4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker.ui import java.io.{File, FileWriter} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.mock import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite @@ -28,47 +28,33 @@ class LogPageSuite extends SparkFunSuite with PrivateMethodTester { test("get logs simple") { val webui = mock(classOf[WorkerWebUI]) - val tmpDir = new File(sys.props("java.io.tmpdir")) - val workDir = new File(tmpDir, "work-dir") - workDir.mkdir() - when(webui.workDir).thenReturn(workDir) val logPage = new LogPage(webui) // Prepare some fake log files to read later val out = "some stdout here" val err = "some stderr here" - val tmpOut = new File(workDir, "stdout") - val tmpErr = new File(workDir, "stderr") - val tmpErrBad = new File(tmpDir, "stderr") // outside the working directory - val tmpOutBad = new File(tmpDir, "stdout") - val tmpRand = new File(workDir, "random") + val tmpDir = new File(sys.props("java.io.tmpdir")) + val tmpOut = new File(tmpDir, "stdout") + val tmpErr = new File(tmpDir, "stderr") + val tmpRand = new File(tmpDir, "random") write(tmpOut, out) write(tmpErr, err) - write(tmpOutBad, out) - write(tmpErrBad, err) write(tmpRand, "1 6 4 5 2 7 8") // Get the logs. 
All log types other than "stderr" or "stdout" will be rejected val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) val (stdout, _, _, _) = - logPage invokePrivate getLog(workDir.getAbsolutePath, "stdout", None, 100) + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) val (stderr, _, _, _) = - logPage invokePrivate getLog(workDir.getAbsolutePath, "stderr", None, 100) + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) val (error1, _, _, _) = - logPage invokePrivate getLog(workDir.getAbsolutePath, "random", None, 100) + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "random", None, 100) val (error2, _, _, _) = - logPage invokePrivate getLog(workDir.getAbsolutePath, "does-not-exist.txt", None, 100) - // These files exist, but live outside the working directory - val (error3, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) - val (error4, _, _, _) = - logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "does-not-exist.txt", None, 100) assert(stdout === out) assert(stderr === err) - assert(error1.startsWith("Error: Log type must be one of ")) - assert(error2.startsWith("Error: Log type must be one of ")) - assert(error3.startsWith("Error: invalid log directory")) - assert(error4.startsWith("Error: invalid log directory")) + assert(error1.startsWith("Error")) + assert(error2.startsWith("Error")) } /** Write the specified string to the file. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 47b2868753c0e..bfcf918e06162 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -254,7 +254,7 @@ class DAGSchedulerSuite test("[SPARK-3353] parent stage should have lower stage id") { sparkListener.stageByOrderOfExecution.clear() sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.stageByOrderOfExecution.length === 2) assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) } @@ -389,7 +389,7 @@ class DAGSchedulerSuite submit(unserializableRdd, Array(0)) assert(failure.getMessage.startsWith( "Job aborted due to stage failure: Task not serializable:")) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -399,7 +399,7 @@ class DAGSchedulerSuite submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted due to stage failure: some failure") - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -410,7 +410,7 @@ class DAGSchedulerSuite val jobId = submit(rdd, Array(0)) cancel(jobId) assert(failure.getMessage === s"Job $jobId cancelled ") - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) 
assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty() @@ -462,7 +462,7 @@ class DAGSchedulerSuite assert(results === Map(0 -> 42)) assertDataStructuresEmpty() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.isEmpty) assert(sparkListener.successfulStages.contains(0)) } @@ -531,7 +531,7 @@ class DAGSchedulerSuite Map[Long, Any](), createFakeTaskInfo(), null)) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. @@ -543,7 +543,7 @@ class DAGSchedulerSuite createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.size == 1) } @@ -592,7 +592,7 @@ class DAGSchedulerSuite // Listener bus should get told about the map stage failing, but not the reduce stage // (since the reduce stage hasn't been started yet). - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.toSet === Set(0)) assertDataStructuresEmpty() @@ -643,7 +643,7 @@ class DAGSchedulerSuite assert(cancelledStages.toSet === Set(0, 2)) // Make sure the listeners got told about both failed stages. - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.successfulStages.isEmpty) assert(sparkListener.failedStages.toSet === Set(0, 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 651295b7344c5..06fb909bf5419 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -47,7 +47,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Starting listener bus should flush all buffered events bus.start(sc) - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(counter.count === 5) // After listener bus has stopped, posting events should not increment counter @@ -131,7 +131,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match rdd2.setName("Target RDD") rdd2.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {1} val (stageInfo, taskInfoMetrics) = listener.stageInfos.head @@ -156,7 +156,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match rdd3.setName("Trois") rdd1.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {1} val stageInfo1 = listener.stageInfos.keys.find(_.stageId == 0).get stageInfo1.rddInfos.size should be {1} // ParallelCollectionRDD @@ -165,7 +165,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match listener.stageInfos.clear() rdd2.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) 
listener.stageInfos.size should be {1} val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD @@ -174,7 +174,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match listener.stageInfos.clear() rdd3.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {2} // Shuffle map stage + result stage val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get stageInfo3.rddInfos.size should be {1} // ShuffledRDD @@ -190,7 +190,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val rdd2 = rdd1.map(_.toString) sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {1} val (stageInfo, _) = listener.stageInfos.head @@ -214,7 +214,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val d = sc.parallelize(0 to 1e4.toInt, 64).map(w) d.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be (1) val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") @@ -225,7 +225,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match d4.setName("A Cogroup") d4.collectAsMap() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be (4) listener.stageInfos.foreach { case (stageInfo, taskInfoMetrics) => /** @@ -281,7 +281,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .reduce { case (x, y) => x } assert(result === 1.to(akkaFrameSize).toArray) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) @@ -297,7 +297,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) val TASK_INDEX = 0 assert(listener.startedTasks.contains(TASK_INDEX)) assert(listener.startedGettingResultTasks.isEmpty) @@ -352,7 +352,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) // The exception should be caught, and the event should be propagated to other listeners assert(bus.listenerThreadIsAlive) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d97fba00976d2..c7f179e1483a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -17,12 +17,12 
@@ package org.apache.spark.scheduler -import scala.collection.mutable +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -import org.apache.spark.scheduler.cluster.ExecutorInfo +import scala.collection.mutable /** * Unit tests for SparkListener that require a local cluster. @@ -41,16 +41,12 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val listener = new SaveExecutorInfo sc.addSparkListener(listener) - // This test will check if the number of executors received by "SparkListener" is same as the - // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) - val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) rdd2.setName("Target RDD") rdd2.count() - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(listener.addedExecutorInfo.size == 2) assert(listener.addedExecutorInfo("0").totalCores == 1) assert(listener.addedExecutorInfo("1").totalCores == 1) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..a867cf83dc3f1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -608,69 +608,4 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { manager.runAll() assert(output.toList === List(4, 3, 2)) } - - test("isInDirectory") { - val tmpDir = new File(sys.props("java.io.tmpdir")) - val parentDir = new File(tmpDir, "parent-dir") - val childDir1 = new File(parentDir, "child-dir-1") - val childDir1b = new File(parentDir, "child-dir-1b") - val childFile1 = new File(parentDir, "child-file-1.txt") - val childDir2 = new File(childDir1, "child-dir-2") - val childDir2b = new File(childDir1, "child-dir-2b") - val childFile2 = new File(childDir1, "child-file-2.txt") - val childFile3 = new File(childDir2, "child-file-3.txt") - val nullFile: File = null - parentDir.mkdir() - childDir1.mkdir() - childDir1b.mkdir() - childDir2.mkdir() - childDir2b.mkdir() - childFile1.createNewFile() - childFile2.createNewFile() - childFile3.createNewFile() - - // Identity - assert(Utils.isInDirectory(parentDir, parentDir)) - assert(Utils.isInDirectory(childDir1, childDir1)) - assert(Utils.isInDirectory(childDir2, childDir2)) - - // Valid ancestor-descendant pairs - assert(Utils.isInDirectory(parentDir, childDir1)) - assert(Utils.isInDirectory(parentDir, childFile1)) - assert(Utils.isInDirectory(parentDir, childDir2)) - assert(Utils.isInDirectory(parentDir, childFile2)) - assert(Utils.isInDirectory(parentDir, childFile3)) - assert(Utils.isInDirectory(childDir1, childDir2)) - assert(Utils.isInDirectory(childDir1, childFile2)) - assert(Utils.isInDirectory(childDir1, childFile3)) - assert(Utils.isInDirectory(childDir2, childFile3)) - - // Inverted ancestor-descendant pairs should fail - assert(!Utils.isInDirectory(childDir1, parentDir)) - assert(!Utils.isInDirectory(childDir2, parentDir)) - assert(!Utils.isInDirectory(childDir2, childDir1)) - assert(!Utils.isInDirectory(childFile1, parentDir)) - assert(!Utils.isInDirectory(childFile2, parentDir)) - assert(!Utils.isInDirectory(childFile3, parentDir)) - 
assert(!Utils.isInDirectory(childFile2, childDir1)) - assert(!Utils.isInDirectory(childFile3, childDir1)) - assert(!Utils.isInDirectory(childFile3, childDir2)) - - // Non-existent files or directories should fail - assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one.txt"))) - assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two.txt"))) - assert(!Utils.isInDirectory(parentDir, new File(parentDir, "one/two/three.txt"))) - - // Siblings should fail - assert(!Utils.isInDirectory(childDir1, childDir1b)) - assert(!Utils.isInDirectory(childDir1, childFile1)) - assert(!Utils.isInDirectory(childDir2, childDir2b)) - assert(!Utils.isInDirectory(childDir2, childFile2)) - - // Null files should fail without throwing NPE - assert(!Utils.isInDirectory(parentDir, nullFile)) - assert(!Utils.isInDirectory(childFile3, nullFile)) - assert(!Utils.isInDirectory(nullFile, parentDir)) - assert(!Utils.isInDirectory(nullFile, childFile3)) - } } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..0b14a618e755c 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Psparkr-docs -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Psparkr-docs -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Psparkr-docs -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ diff --git a/dev/run-tests b/dev/run-tests index d178e2a4601ea..7dd8d31fd44e3 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -80,19 +80,18 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" # Only run Hive tests if there are SQL changes. # Partial solution for SPARK-1455. 
if [ -n "$AMPLAB_JENKINS" ]; then - target_branch="$ghprbTargetBranch" - git fetch origin "$target_branch":"$target_branch" + git fetch origin master:master # AMP_JENKINS_PRB indicates if the current build is a pull request build. if [ -n "$AMP_JENKINS_PRB" ]; then # It is a pull request build. sql_diffs=$( - git diff --name-only "$target_branch" \ + git diff --name-only master \ | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) non_sql_diffs=$( - git diff --name-only "$target_branch" \ + git diff --name-only master \ | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" ) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 641b0ff3c4be4..8b2a44fd72ba5 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,9 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -# format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 180m) -TESTS_TIMEOUT="175m" +TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. @@ -193,7 +191,7 @@ done test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}console)** \ + fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ after a configured wait of \`${TESTS_TIMEOUT}\`." @@ -233,7 +231,7 @@ done # post end message { result_message="\ - [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}console) for \ + [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9c..b22b627f09007 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.4.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/docs/monitoring.md b/docs/monitoring.md index bcf885fe4e681..e75018499003a 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -228,14 +228,6 @@ for a running application, at `http://localhost:4040/api/v1`. /applications/[app-id]/storage/rdd/[rdd-id] Details for the storage status of a given RDD - - /applications/[app-id]/logs - Download the event logs for all attempts of the given application as a zip file - - - /applications/[app-id]/[attempt-id]/logs - Download the event logs for the specified attempt of the given application as a zip file - When running on Yarn, each application has multiple attempts, so `[app-id]` is actually diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cde5830c733e0..282ea75e1e785 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1785,13 +1785,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. 
- - spark.sql.planner.externalSort - false - - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. - - # Distributed SQL Engine diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index d6d5605948a5a..64714f0b799fc 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
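For illustration, a minimal Scala sketch of the `createStream` variant that takes explicit key/value classes and decoder classes, as the documentation above describes; the ZooKeeper quorum, consumer group, topic name, and app name below are placeholder values, not part of this patch.

    // Hypothetical settings: adjust the ZooKeeper quorum, group id, and topic for your cluster.
    import kafka.serializer.StringDecoder
    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.KafkaUtils

    val ssc = new StreamingContext(new SparkConf().setAppName("KafkaReceiverSketch"), Seconds(2))
    val kafkaParams = Map("zookeeper.connect" -> "localhost:2181", "group.id" -> "example-group")
    val topics = Map("example-topic" -> 1) // topic -> number of receiver threads
    // Explicit key/value classes (String) and their corresponding decoders (StringDecoder)
    val lines = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2).map(_._2)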
@@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*; @@ -116,7 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
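Similarly, a minimal Scala sketch of the direct (receiver-less) variant discussed above; the broker address and topic name are placeholder values assumed for the example.

    // Hypothetical settings: adjust metadata.broker.list and the topic set for your cluster.
    import kafka.serializer.StringDecoder
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.KafkaUtils

    val ssc = new StreamingContext(new SparkConf().setAppName("KafkaDirectSketch"), Seconds(2))
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
    val topics = Set("example-topic")
    // With the direct approach, offsets are tracked by Spark Streaming itself
    // rather than by ZooKeeper-backed receivers.
    val lines = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics).map(_._2)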
@@ -153,4 +153,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. +3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 84629cb9a0ca0..ee0904c9e5d54 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -219,8 +219,7 @@ def parse_args(): "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + - "(Hadoop 2.4.0) (default: %default)") + help="Major version of Hadoop (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -272,8 +271,7 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. 
Not used if YARN " + - "is used as Hadoop major version (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -763,10 +761,6 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): if opts.ganglia: modules.append('ganglia') - # Clear SPARK_WORKER_INSTANCES if running on YARN - if opts.hadoop_major_version == "yarn": - opts.worker_instances = "" - # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( @@ -1004,7 +998,6 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] - worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { "master_list": '\n'.join(master_addresses), "active_master": active_master, @@ -1018,7 +1011,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_version": spark_v, "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": worker_instances_str, + "spark_worker_instances": "%d" % opts.worker_instances, "spark_master_opts": opts.master_opts } diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca94..e4efee7b5e647 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index e1fd85b082c08..96ddac761d698 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'org.apache.parquet.avro.AvroParquetInputFormat', + 'parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 7a7dccc3d0922..71f2b6fe18bd1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e0..a345c03582ad6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 8059c443827ef..0b79f47647f6b 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e8..5734d55bf4784 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784b..7d102e10ab60f 100644 
--- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b57..d28e3e1846d70 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d43663..9998c11c85171 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9e..4351a8a12fe21 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c6f60bc907438..25847a1b33d9c 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml @@ -40,13 +40,6 @@ spark-streaming_${scala.binary.version} ${project.version} - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.spark spark-streaming_${scala.binary.version} diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f0..e14bbae4a9b6e 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795e..28b41228feb3d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 48dd0d5f9106b..cc177d23dff77 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3bb..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,33 +53,21 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); + boolean printLaunchCommand; + boolean printUsage; AbstractCommandBuilder builder; - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { - try { + try { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { builder = new SparkSubmitCommandBuilder(args); - } catch (IllegalArgumentException e) { - printLaunchCommand = false; - System.err.println("Error: " + e.getMessage()); - System.err.println(); - - MainClassOptionParser parser = new MainClassOptionParser(); - try { - parser.parse(args); - } catch (Exception ignored) { - // Ignore parsing exceptions. 
- } - - List help = new ArrayList(); - if (parser.className != null) { - help.add(parser.CLASS); - help.add(parser.className); - } - help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + } else { + builder = new SparkClassCommandBuilder(className, args); } - } else { - builder = new SparkClassCommandBuilder(className, args); + printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); + printUsage = false; + } catch (IllegalArgumentException e) { + builder = new UsageCommandBuilder(e.getMessage()); + printLaunchCommand = false; + printUsage = true; } Map env = new HashMap(); @@ -90,7 +78,13 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - System.out.println(prepareWindowsCommand(cmd, env)); + // When printing the usage message, we can't use "cmd /v" since that prevents the env + // variable from being seen in the caller script. So do not call prepareWindowsCommand(). + if (printUsage) { + System.out.println(join(" ", cmd)); + } else { + System.out.println(prepareWindowsCommand(cmd, env)); + } } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); @@ -141,30 +135,33 @@ private static List prepareBashCommand(List cmd, Map extra) { - + public List buildCommand(Map env) { + if (isWindows()) { + return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); + } else { + return Arrays.asList("usage", message, "1"); + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 3e5a2820b6c11..7d387d406edae 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,7 +77,6 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; - private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. 
This is needed @@ -88,11 +87,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); - this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this.sparkArgs = new ArrayList(); + this(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -106,16 +104,14 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - OptionParser parser = new OptionParser(); - parser.parse(submitArgs); - this.printHelp = parser.helpRequested; + new OptionParser().parse(submitArgs); } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -315,8 +311,6 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { - boolean helpRequested = false; - @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -347,9 +341,6 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } - } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { - helpRequested = true; - sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -369,7 +360,6 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { - checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index b88bba883ac65..229000087688f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,7 +61,6 @@ class SparkSubmitOptionParser { // Options that do not take arguments. protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; - protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -121,7 +120,6 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, - { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, }; diff --git a/mllib/pom.xml b/mllib/pom.xml index b16058ddc203a..65c647a91d192 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f4e250757560a..a2dc8a8b960c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -88,9 +88,6 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. 
- * NOTE: During transformation, if the input column does not exist, - * [[StringIndexerModel.transform]] would return the input dataset unmodified. - * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental class StringIndexerModel private[ml] ( @@ -115,12 +112,6 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { - if (!dataset.schema.fieldNames.contains($(inputCol))) { - logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + - "Skip StringIndexerModel.") - return dataset - } - val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -137,11 +128,6 @@ class StringIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains($(inputCol))) { - validateAndTransformSchema(schema) - } else { - // If the input column does not exist during transformation, we skip StringIndexerModel. - schema - } + validateAndTransformSchema(schema) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ba94d6a3a80a9..473488dce9b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -69,10 +69,14 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - /** Creates a param pair with the given value (for Java). */ + /** + * Creates a param pair with the given value (for Java). + */ def w(value: T): ParamPair[T] = this -> value - /** Creates a param pair with the given value (for Scala). */ + /** + * Creates a param pair with the given value (for Scala). + */ def ->(value: T): ParamPair[T] = ParamPair(this, value) override final def toString: String = s"${parent}__$name" @@ -186,7 +190,6 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) - /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) } @@ -206,7 +209,6 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) - /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) } @@ -226,7 +228,6 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) - /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) } @@ -246,7 +247,6 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) - /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) } @@ -260,7 +260,6 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) - /** Creates a param pair with the given value (for Java). 
*/ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -275,6 +274,8 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } @@ -290,9 +291,10 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ - def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = - w(value.asScala.map(_.asInstanceOf[Double]).toArray) + def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cb29392e8bc63..6434b64aed15d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -158,8 +158,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM @Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) + val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(): Unit = { @@ -176,10 +175,7 @@ class CrossValidatorModel private[ml] ( } override def copy(extra: ParamMap): CrossValidatorModel = { - val copied = new CrossValidatorModel( - uid, - bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone()) + val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) copyValues(copied, extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index fc509d2ba1470..70b0e40948e51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -189,9 
+188,6 @@ class GaussianMixture private ( new GaussianMixtureModel(weights, gaussians) } - /** Java-friendly version of [[run()]] */ - def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) - /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c8038101..5fc2cb1b62d33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -25,7 +25,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} @@ -47,7 +46,7 @@ import org.apache.spark.sql.{SQLContext, Row} @Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") @@ -66,10 +65,6 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } - /** Java-friendly version of [[predict()]] */ - def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = - predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] - /** * Given the input vectors, return the membership value of each vector * to all mixture components. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 974b26924dfb8..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.rdd.RDD @@ -346,11 +345,6 @@ class DistributedLDAModel private ( } } - /** Java-friendly version of [[topicDistributions]] */ - def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { - JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) - } - // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
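
Note: the clustering hunks above drop the Java-friendly overloads (run(JavaRDD), predict(JavaRDD), javaTopicDistributions), leaving only the Scala RDD-based signatures. The following is a minimal Scala sketch of the surviving API, reusing the sample vectors from the JavaGaussianMixtureSuite that this patch deletes further down; the object name GmmSketch and the local[2] master are illustrative assumptions, not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

object GmmSketch {
  def main(args: Array[String]): Unit = {
    // Local context for illustration only.
    val sc = new SparkContext(new SparkConf().setAppName("gmm-sketch").setMaster("local[2]"))
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 6.0),
      Vectors.dense(1.0, 3.0, 0.0),
      Vectors.dense(1.0, 4.0, 6.0)))
    // With the run(JavaRDD) overload removed, callers pass a Scala RDD[Vector] directly.
    val model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234L).run(data)
    require(model.gaussians.length == 2)
    // predict(RDD[Vector]) returns RDD[Int]; the predict(JavaRDD) overload is gone.
    model.predict(data).collect().foreach(println)
    sc.stop()
  }
}
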
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d9b34cec64894..c21e4fe7dc9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,10 +21,8 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -236,9 +234,6 @@ class StreamingKMeans( } } - /** Java-friendly version of `trainOn`. */ - def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) - /** * Use the clustering model to make predictions on batches of data from a DStream. * @@ -250,11 +245,6 @@ class StreamingKMeans( data.map(model.predict) } - /** Java-friendly version of `predictOn`. */ - def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { - JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) - } - /** * Use the model to make predictions on the values of a DStream and carry over its keys. * @@ -267,14 +257,6 @@ class StreamingKMeans( data.mapValues(model.predict) } - /** Java-friendly version of `predictOnValues`. */ - def predictOnValues[K]( - data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { - implicit val tag = fakeClassTag[K] - JavaPairDStream.fromPairDStream( - predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]]) - } - /** Check whether cluster centers have been initialized. */ private[this] def assertInitialized(): Unit = { if (model.clusterCenters == null) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 900007ec6bc74..b3fad0c52d655 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -81,10 +80,6 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) - /** Java-friendly version of [[corr()]] */ - def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = - corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) - /** * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. 
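
Note: both Statistics.corr hunks in this file remove the JavaRDD[java.lang.Double] convenience overloads, so only the RDD[Double] signatures remain. A short Scala sketch of those surviving entry points, using the sample series from the JavaStatisticsSuite deleted below; the object name CorrSketch and the local[2] master are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.stat.Statistics

object CorrSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("corr-sketch").setMaster("local[2]"))
    val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val y = sc.parallelize(Seq(1.1, 2.2, 3.1, 4.3))
    // Only the RDD[Double] signatures are left after this patch.
    val defaultCorr = Statistics.corr(x, y)          // Pearson by default
    val pearson     = Statistics.corr(x, y, "pearson")
    assert(defaultCorr == pearson)                   // default method is Pearson
    sc.stop()
  }
}
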
@@ -101,10 +96,6 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) - /** Java-friendly version of [[corr()]] */ - def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = - corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) - /** * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java similarity index 95% rename from mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java rename to mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java index 55787f8606d48..640d2ec55e4e7 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.classification; +package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.List; @@ -28,6 +28,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index 9890155e9f865..e7df10dfa63ac 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -50,7 +50,6 @@ public void testParams() { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); - Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index ff5929235ac2c..947ae3a2ce06f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -51,8 +51,7 @@ public String uid() { public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam_, value); - return this; + set(myIntParam_, value); return this; } private DoubleParam myDoubleParam_; @@ -61,8 +60,7 @@ public JavaTestParams setMyIntParam(int value) { public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam_, value); - return this; + set(myDoubleParam_, value); return this; } private Param myStringParam_; @@ -71,18 +69,7 @@ public JavaTestParams setMyDoubleParam(double value) { public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam_, value); - return this; - 
} - - private DoubleArrayParam myDoubleArrayParam_; - public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } - - public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } - - public JavaTestParams setMyDoubleArrayParam(double[] value) { - set(myDoubleArrayParam_, value); - return this; + set(myStringParam_, value); return this; } private void init() { @@ -92,14 +79,8 @@ private void init() { List validStrings = Lists.newArrayList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - myDoubleArrayParam_ = - new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); - - setDefault(myIntParam(), 1); - setDefault(myIntParam().w(1)); - setDefault(myDoubleParam(), 0.5); + setDefault(myIntParam_, 1); + setDefault(myDoubleParam_, 0.5); setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); - setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); - setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java deleted file mode 100644 index 467a7a69e8f30..0000000000000 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.clustering; - -import java.io.Serializable; -import java.util.List; - -import com.google.common.collect.Lists; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -public class JavaGaussianMixtureSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaGaussianMixture"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - @Test - public void runGaussianMixture() { - List points = Lists.newArrayList( - Vectors.dense(1.0, 2.0, 6.0), - Vectors.dense(1.0, 3.0, 0.0), - Vectors.dense(1.0, 4.0, 6.0) - ); - - JavaRDD data = sc.parallelize(points, 2); - GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) - .run(data); - assertEquals(model.gaussians().length, 2); - JavaRDD predictions = model.predict(data); - predictions.first(); - } -} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 581c033f08ebe..96c2da169961f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -107,10 +107,6 @@ public void distributedLDAModel() { // Check: log probabilities assert(model.logLikelihood() < 0.0); assert(model.logPrior() < 0.0); - - // Check: topic distributions - JavaPairRDD topicDistributions = model.javaTopicDistributions(); - assertEquals(topicDistributions.count(), corpus.count()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java deleted file mode 100644 index 3b0e879eec77f..0000000000000 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.clustering; - -import java.io.Serializable; -import java.util.List; - -import scala.Tuple2; - -import com.google.common.collect.Lists; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import static org.apache.spark.streaming.JavaTestUtils.*; - -import org.apache.spark.SparkConf; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; - -public class JavaStreamingKMeansSuite implements Serializable { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } - - @Test - @SuppressWarnings("unchecked") - public void javaAPI() { - List trainingBatch = Lists.newArrayList( - Vectors.dense(1.0), - Vectors.dense(0.0)); - JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); - JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); - StreamingKMeans skmeans = new StreamingKMeans() - .setK(1) - .setDecayFactor(1.0) - .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0}); - skmeans.trainOn(training); - JavaPairDStream prediction = skmeans.predictOnValues(test); - attachTestOutputStream(prediction.count()); - runStreams(ssc, 2, 2); - } -} diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java deleted file mode 100644 index 62f7f26b7c98f..0000000000000 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.stat; - -import java.io.Serializable; - -import com.google.common.collect.Lists; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; - -public class JavaStatisticsSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - @Test - public void testCorr() { - JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); - - Double corr1 = Statistics.corr(x, y); - Double corr2 = Statistics.corr(x, y, "pearson"); - // Check default method - assertEquals(corr1, corr2); - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5f557e16e5150..cbf1e8ddcb48a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -60,12 +60,4 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } - - test("StringIndexerModel should keep silent if the input column does not exist.") { - val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) - .setInputCol("label") - .setOutputCol("labelIndex") - val df = sqlContext.range(0L, 10L) - assert(indexerModel.transform(df).eq(df)) - } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 9b3619f0046ea..5ba469c7b10a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -56,7 +56,6 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.avgMetrics.length === lrParamMaps.length) } test("validateParams should check estimatorParamMaps") { diff --git a/network/common/pom.xml b/network/common/pom.xml index a85e0a66f4a30..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 4b5bfcb6f04bc..7dc7c65825e34 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a99f7c4392d3d..1e2e9c80af6cc 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index e28d4b9fc2b17..711edf9efad2b 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT pom Spark Project Parent POM 
http://spark.apache.org/ @@ -118,6 +118,7 @@ 2.3.4-spark 1.6 spark + 2.0.1 0.21.1 shaded-protobuf 1.7.10 @@ -136,7 +137,7 @@ 0.13.1 10.10.1.1 - 1.7.0 + 1.6.0rc3 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -268,18 +269,6 @@ false - - - spark-1.4-staging - Spark 1.4 RC4 Staging Repository - https://repository.apache.org/content/repositories/orgapachespark-1112 - - true - - - false - - @@ -1080,13 +1069,13 @@ - org.apache.parquet + com.twitter parquet-column ${parquet.version} ${parquet.deps.scope} - org.apache.parquet + com.twitter parquet-hadoop ${parquet.version} ${parquet.deps.scope} @@ -1216,6 +1205,15 @@ -target ${java.version} + + + + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} + + @@ -1254,7 +1252,6 @@ ${test.java.home} - test true ${spark.test.home} 1 @@ -1287,7 +1284,6 @@ ${test.java.home} - test true ${spark.test.home} 1 @@ -1430,8 +1426,6 @@ 2.3 false - - false diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 5812b72f0aa78..dde92949fa175 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,8 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - // TODO: Change this once Spark 1.4.0 is released - val previousSparkVersion = "1.4.0-rc4" + val previousSparkVersion = "1.3.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 73e4bfd78e577..8da72b3fa7cdb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,22 +34,6 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object MimaExcludes { def excludes(version: String) = version match { - case v if v.startsWith("1.5") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // JavaRDDLike is not meant to be extended by user programs - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.partitioner"), - // Mima false positive (was a private[spark] class) - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution") - ) case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ef3a175bac209..9a849639233bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -178,6 +178,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst macro settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -272,6 +275,14 @@ object OldDeps { ) } +object Catalyst { + lazy val settings = Seq( + addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), + // Quasiquotes break compiling scala doc... + // TODO: Investigate fixing this. 
+ sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) +} + object SQL { lazy val settings = Seq( initialCommands in console := @@ -504,7 +515,6 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", - javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 44d90f1437bc9..aeb7ad4f2f83e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -324,12 +324,10 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None - def range(self, start, end=None, step=1, numSlices=None): + def range(self, start, end, step=1, numSlices=None): """ Create a new RDD of int containing elements from `start` to `end` - (exclusive), increased by `step` every element. Can be called the same - way as python's built-in range() function. If called with a single argument, - the argument is interpreted as `end`, and `start` is set to 0. + (exclusive), increased by `step` every element. :param start: the start value :param end: the end value (exclusive) @@ -337,17 +335,9 @@ def range(self, start, end=None, step=1, numSlices=None): :param numSlices: the number of partitions of the new RDD :return: An RDD of int - >>> sc.range(5).collect() - [0, 1, 2, 3, 4] - >>> sc.range(2, 4).collect() - [2, 3] >>> sc.range(1, 7, 2).collect() [1, 3, 5] """ - if end is None: - end = start - start = 0 - return self.parallelize(xrange(start, end, step), numSlices) def parallelize(self, c, numSlices=None): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 599c9ac5794a2..9fdf43c3e6eb5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -131,7 +131,7 @@ def udf(self): return UDFRegistration(self) @since(1.4) - def range(self, start, end=None, step=1, numPartitions=None): + def range(self, start, end, step=1, numPartitions=None): """ Create a :class:`DataFrame` with single LongType column named `id`, containing elements in a range from `start` to `end` (exclusive) with @@ -145,20 +145,10 @@ def range(self, start, end=None, step=1, numPartitions=None): >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] - - If only one argument is specified, it will be used as the end value. 
- - >>> sqlContext.range(3).collect() - [Row(id=0), Row(id=1), Row(id=2)] """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - - if end is None: - jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) - else: - jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) - + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) return DataFrame(jdf, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 902504df5b11b..7673153abe0e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -616,19 +616,7 @@ def describe(self, *cols): | min| 2| | max| 5| +-------+---+ - >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ """ - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @@ -1201,30 +1189,15 @@ def withColumnRenamed(self, existing, new): @since(1.4) @ignore_unicode_prefix - def drop(self, col): + def drop(self, colName): """Returns a new :class:`DataFrame` that drops the specified column. - :param col: a string name of the column to drop, or a - :class:`Column` to drop. + :param colName: string, name of the column to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] - - >>> df.drop(df.age).collect() - [Row(name=u'Alice'), Row(name=u'Bob')] - - >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() - [Row(age=5, height=85, name=u'Bob')] - - >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() - [Row(age=5, name=u'Bob', height=85)] """ - if isinstance(col, basestring): - jdf = self._jdf.drop(col) - elif isinstance(col, Column): - jdf = self._jdf.drop(col._jc) - else: - raise TypeError("col should be a string or a Column") + jdf = self._jdf.drop(colName) return DataFrame(jdf, self.sql_ctx) @since(1.3) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a6fce50c76c2b..6e498f0af0af5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -131,8 +131,6 @@ def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.sqlCtx.range(-2).count(), 0) - self.assertEqual(self.sqlCtx.range(3).count(), 3) def test_explode(self): from pyspark.sql.functions import explode diff --git a/repl/pom.xml b/repl/pom.xml index 85f7bc8ac1024..6e5cb7f77e1df 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index f4b1cc3a4ffe7..d9e1cdb84bb27 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml @@ -36,6 +36,10 @@ + + org.scala-lang + scala-compiler + org.scala-lang scala-reflect @@ -63,11 +67,6 @@ scalacheck_${scala.binary.version} test - - org.codehaus.janino - janino - 2.7.8 - target/scala-${scala.binary.version}/classes @@ -109,6 +108,13 @@ !scala-2.11 + + + org.scalamacros + quasiquotes_${scala.binary.version} + ${scala.macros.version} + + diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ec97fe603c44f..bb546b3086b33 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,25 +17,23 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - +import scala.collection.Map; import scala.collection.Seq; import scala.collection.mutable.ArraySeq; +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; + import org.apache.spark.sql.Row; -import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; +import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; -import static org.apache.spark.sql.types.DataTypes.*; - /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -51,7 +49,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends BaseMutableRow { +public final class UnsafeRow implements MutableRow { private Object baseObject; private long baseOffset; @@ -229,11 +227,21 @@ public int size() { return numFields; } + @Override + public int length() { + return size(); + } + @Override public StructType schema() { return schema; } + @Override + public Object apply(int i) { + return get(i); + } + @Override public Object get(int i) { assertIndexIsValid(i); @@ -331,7 +339,60 @@ public String getString(int i) { return getUTF8String(i).toString(); } + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } @Override public Row copy() { @@ -351,4 +412,24 @@ public Seq toSeq() { } return values; } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, 
end); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java deleted file mode 100644 index acec2bf4520f2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseMutableRow.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import org.apache.spark.sql.catalyst.expressions.MutableRow; - -public abstract class BaseMutableRow extends BaseRow implements MutableRow { - - @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInt(int ordinal, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setLong(int ordinal, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setDouble(int ordinal, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShort(int ordinal, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setByte(int ordinal, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setFloat(int ordinal, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java deleted file mode 100644 index d138b43a3482b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql; - -import java.math.BigDecimal; -import java.sql.Date; -import java.util.List; - -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.StructType; - -public abstract class BaseRow implements Row { - - @Override - final public int length() { - return size(); - } - - @Override - public boolean anyNull() { - final int n = size(); - for (int i=0; i < n; i++) { - if (isNullAt(i)) { - return true; - } - } - return false; - } - - @Override - public StructType schema() { throw new UnsupportedOperationException(); } - - @Override - final public Object apply(int i) { - return get(i); - } - - @Override - public int getInt(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - - @Override - public Row copy() { - final int n = size(); - Object[] arr = new Object[n]; - for (int i = 0; i < n; i++) { - arr[i] = get(i); - } - return new GenericRow(arr); - } - - @Override - public Seq toSeq() { - final int n = size(); - final ArraySeq values = new ArraySeq(n); - for (int i = 0; i < n; i++) { - values.update(i, get(i)); - } - return values; - } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
index 5883d938b676d..bc17169f35a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,8 +235,9 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from - * a logical plan node's children. + * Replaces [[UnresolvedAttribute]]s with concrete + * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's + * children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -454,7 +455,7 @@ class Analyzer( } /** - * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. + * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -845,8 +846,9 @@ class Analyzer( } /** - * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide - * scoping information for attributes and can be removed once analysis is complete. + * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are + * only required to provide scoping information for attributes and can be removed once analysis is + * complete. */ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 9b8a08a88dcb0..b064600e94fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -130,7 +130,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. 
*/ object ConvertNaNs extends Rule[LogicalPlan] { - private val StringNaN = Literal("NaN") + private val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -138,20 +138,20 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => - b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => - b.makeCopy(Array(left, Literal(Double.NaN))) + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => + b.makeCopy(Array(b.right, Literal(Double.NaN))) + case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => - b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => - b.makeCopy(Array(left, Literal(Float.NaN))) - - /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => - b.makeCopy(Array(left, Literal(Float.NaN))) + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => + b.makeCopy(Array(b.right, Literal(Float.NaN))) + case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) } } } @@ -184,25 +184,21 @@ trait HiveTypeCoercion { case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) - - case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => + case (l, r) if l.dataType == StringType && r.dataType != StringType => + (l, Alias(Cast(r, StringType), r.name)()) + case (l, r) if l.dataType != StringType && r.dataType == StringType => + (Alias(Cast(l, StringType), l.name)(), r) + + case (l, r) if l.dataType != r.dataType => + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() + if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() + if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) - } - + }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. 
case other => other } @@ -231,10 +227,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + case b: BinaryExpression if b.left.dataType != b.right.dataType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => + val newLeft = + if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) + val newRight = + if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) b.makeCopy(Array(newLeft, newRight)) }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. } @@ -249,42 +247,57 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), r) => - a.makeCopy(Array(Cast(left, DoubleType), r)) - case a @ BinaryArithmetic(left, right @ StringType()) => - a.makeCopy(Array(left, Cast(right, DoubleType))) + case a: BinaryArithmetic if a.left.dataType == StringType => + a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) + case a: BinaryArithmetic if a.right.dataType == StringType => + a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) // we should cast all timestamp/date/string compare into string compare - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => - p.makeCopy(Array(Cast(left, DoubleType), right)) - case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, DoubleType))) - - case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) => + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType == DateType => + p.makeCopy(Array(p.left, Cast(p.right, StringType))) + case p: BinaryComparison if p.left.dataType == DateType && + p.right.dataType == StringType => + p.makeCopy(Array(Cast(p.left, StringType), p.right)) + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) + case p: BinaryComparison if p.left.dataType == TimestampType && + p.right.dataType == StringType => + p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) + case p: BinaryComparison if p.left.dataType == TimestampType && + p.right.dataType == DateType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, 
StringType))) + case p: BinaryComparison if p.left.dataType == DateType && + p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) + + case p: BinaryComparison if p.left.dataType == StringType && + p.right.dataType != StringType => + p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) + case p: BinaryComparison if p.left.dataType != StringType && + p.right.dataType == StringType => + p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) + + case i @ In(a, b) if a.dataType == DateType && + b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) => + case i @ In(a, b) if a.dataType == TimestampType && + b.forall(_.dataType == StringType) => i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) - case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) => + case i @ In(a, b) if a.dataType == DateType && + b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) => + case i @ In(a, b) if a.dataType == TimestampType && + b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) + case Sum(e) if e.dataType == StringType => + Sum(Cast(e, DoubleType)) + case Average(e) if e.dataType == StringType => + Average(Cast(e, DoubleType)) + case Sqrt(e) if e.dataType == StringType => + Sqrt(Cast(e, DoubleType)) } } @@ -366,22 +379,22 @@ trait HiveTypeCoercion { // fix decimal precision for union case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) - case _ => (lhs, rhs) + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) } case other => other } @@ -454,16 +467,16 @@ trait HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - 
case b @ BinaryExpression(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) case _ => b } @@ -512,31 +525,31 @@ trait HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values - case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => left - case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => Not(left) - case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) - if trueValues.contains(value) => right - case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) - if falseValues.contains(value) => Not(right) - case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) - if trueValues.contains(value) => And(IsNotNull(left), left) - case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) - if falseValues.contains(value) => And(IsNotNull(left), Not(left)) - case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) - if trueValues.contains(value) => And(IsNotNull(right), right) - case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) - if falseValues.contains(value) => And(IsNotNull(right), Not(right)) - - case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left , right) - case EqualTo(left @ NumericType(), right @ BooleanType()) => - transform(right, left) - case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => - transformNullSafe(left, right) - case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => - transformNullSafe(right, left) + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => l + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(l) + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => r + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => Not(r) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(l), l) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(l), Not(l)) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(r), r) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(r), Not(r)) + + case EqualTo(l @ BooleanType(), r @ NumericType()) => + transform(l , r) + case EqualTo(l @ 
NumericType(), r @ BooleanType()) => + transform(r, l) + case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => + transformNullSafe(l, r) + case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => + transformNullSafe(r, l) } } @@ -617,7 +630,7 @@ trait HiveTypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b2b9d1a5e1581..3cf851aec15ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -118,10 +118,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" } -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) -} - abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a3770f998d94d..2ac53f8f6613f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -118,10 +118,6 @@ abstract class BinaryArithmetic extends BinaryExpression { sys.error(s"BinaryArithmetics must override either eval or evalInternal") } -private[sql] object BinaryArithmetic { - def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) -} - case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd604121b7dd9..36964af68dd8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable -import scala.language.existentials +import com.google.common.cache.{CacheLoader, CacheBuilder} -import com.google.common.cache.{CacheBuilder, CacheLoader} -import org.codehaus.janino.ClassBodyEvaluator +import scala.language.existentials import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions @@ -38,15 +36,23 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * expressions. 
*/ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + import scala.tools.reflect.ToolBox + + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() - protected val rowType = classOf[Row].getName - protected val stringType = classOf[UTF8String].getName - protected val decimalType = classOf[Decimal].getName - protected val exprType = classOf[Expression].getName - protected val mutableRowType = classOf[MutableRow].getName - protected val genericMutableRowType = classOf[GenericMutableRow].getName + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] + + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] private val curId = new java.util.concurrent.atomic.AtomicInteger() + private val javaSeparator = "$" /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -68,20 +74,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType - /** - * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile - */ - protected def compile(code: String): Class[_] = { - val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") - clazz - } - /** * A cache of generated classes. * @@ -95,7 +87,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .maximumSize(1000) .build( new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { + override def load(in: InType): OutType = globalLock.synchronized { val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() @@ -118,8 +110,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - protected def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + protected def freshName(prefix: String): TermName = { + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") } /** @@ -133,51 +125,32 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ protected case class EvaluatedExpression( - code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) - - /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. - */ - protected class CodeGenContext { - /** - * Holding all the expressions those do not support codegen, will be evaluated directly. 
- */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - } - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext() - } + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) /** * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that * can be used to determine the result of evaluating the expression on an input row. */ - def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { + def expressionEvaluator(e: Expression): EvaluatedExpression = { val primitiveTerm = freshName("primitiveTerm") val nullTerm = freshName("nullTerm") val objectTerm = freshName("objectTerm") implicit class Evaluate1(e: Expression) { - def castOrNull(f: String => String, dataType: DataType): String = { - val eval = expressionEvaluator(e, ctx) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - if (!$nullTerm) { - $primitiveTerm = ${f(eval.primitiveTerm)}; - } - """ + def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${f(eval.primitiveTerm)} + """.children } } @@ -190,505 +163,529 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * @param f a function from two primitive term names to a tree that evaluates them. */ - def evaluate(f: (String, String) => String): String = + def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = evaluateAs(expressions._1.dataType)(f) - def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { + def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { // TODO: Right now some timestamp tests fail if we enforce this... if (expressions._1.dataType != expressions._2.dataType) { log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") } - val eval1 = expressionEvaluator(expressions._1, ctx) - val eval2 = expressionEvaluator(expressions._2, ctx) + val eval1 = expressionEvaluator(expressions._1) + val eval2 = expressionEvaluator(expressions._2) val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - eval1.code + eval2.code + - s""" - boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; - if(!$nullTerm) { - $primitiveTerm = (${primitiveForType(resultType)})($resultCode); - } - """ + eval1.code ++ eval2.code ++ + q""" + val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} + val $primitiveTerm: ${termForType(resultType)} = + if($nullTerm) { + ${defaultPrimitive(resultType)} + } else { + $resultCode.asInstanceOf[${termForType(resultType)}] + } + """.children : Seq[Tree] } } - val inputTuple = "i" + val inputTuple = newTermName(s"i") // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, String] = { + val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { case b @ BoundReference(ordinal, dataType, nullable) => - s""" - final boolean $nullTerm = $inputTuple.isNullAt($ordinal); - final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? 
- ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); - """ + val nullValue = q"$inputTuple.isNullAt($ordinal)" + q""" + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${getColumn(inputTuple, dataType, ordinal)} + """.children case expressions.Literal(null, dataType) => - s""" - final boolean $nullTerm = true; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - """ - - case expressions.Literal(value: UTF8String, StringType) => - val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean $nullTerm = false; - ${stringType} $primitiveTerm = - new ${stringType}().set(${arr}); - """ - - case expressions.Literal(value, FloatType) => - s""" - final boolean $nullTerm = false; - float $primitiveTerm = ${value}f; - """ - - case expressions.Literal(value, dt @ DecimalType()) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); - """ - - case expressions.Literal(value, dataType) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dataType)} $primitiveTerm = $value; - """ - - case Cast(child @ BinaryType(), StringType) => - child.castOrNull(c => - s"new ${stringType}().set($c)", - StringType) + q""" + val $nullTerm = true + val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] + """.children + + case expressions.Literal(value: Boolean, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: UTF8String, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) + """.children + + case expressions.Literal(value: Int, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Long, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case Cast(e @ BinaryType(), StringType) => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + """.children case Cast(child @ DateType(), StringType) => child.castOrNull(c => - s"""new ${stringType}().set( + q"""org.apache.spark.sql.types.UTF8String( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) + case Cast(child @ NumericType(), IntegerType) => + child.castOrNull(c => q"$c.toInt", IntegerType) - case Cast(child @ DecimalType(), IntegerType) => - child.castOrNull(c => s"($c).toInt()", IntegerType) + case Cast(child @ NumericType(), LongType) => + child.castOrNull(c => q"$c.toLong", LongType) - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) + case Cast(child @ NumericType(), DoubleType) => + child.castOrNull(c => q"$c.toDouble", DoubleType) - case Cast(child @ NumericType(), dt: NumericType) if 
!dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) + case Cast(child @ NumericType(), FloatType) => + child.castOrNull(c => q"$c.toFloat", FloatType) // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case Cast(e, StringType) if e.dataType != TimestampType => - e.castOrNull(c => - s"new ${stringType}().set(String.valueOf($c))", - StringType) + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) + """.children case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ } case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + + /* TODO: Fix null semantics. + case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => + val eval = expressionEvaluator(e1) + + val checks = list.map { + case expressions.Literal(v: String, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + case expressions.Literal(v: Int, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + } + + val funcName = newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + */ case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } case And(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + q""" + ..${eval1.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { } else { - ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { + ..${eval2.code} + if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true; + $primitiveTerm = true } else { - $nullTerm = true; + 
$nullTerm = true } } - """ + """.children case Or(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; + q""" + ..${eval1.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true; + $primitiveTerm = true } else { - ${eval2.code} + ..${eval2.code} if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true; + $primitiveTerm = true } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false; + $primitiveTerm = false } else { - $nullTerm = true; + $nullTerm = true } } - """ + """.children case Not(child) => // Uh, bad function name... - child.castOrNull(c => s"!$c", BooleanType) - - case Add(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } - case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } - case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } - case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = null; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); - } - """ - case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); - } - """ - - case Add(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } - case Subtract(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } - case Multiply(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } + child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = 0 + + if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else if (${eval2.primitiveTerm} == 0) + 
$nullTerm = true + else { + $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} } - """ + """.children + case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = 0 + + if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else if (${eval2.primitiveTerm} == 0) + $nullTerm = true + else { + $nullTerm = false + $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} } - """ + """.children case IsNotNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = !${eval.nullTerm}; - """ + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children case IsNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = ${eval.nullTerm}; - """ - - case e @ Coalesce(children) => - s""" - boolean $nullTerm = true; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - """ + + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ children.map { c => - val eval = expressionEvaluator(c, ctx) - s""" + val eval = expressionEvaluator(c) + q""" if($nullTerm) { - ${eval.code} + ..${eval.code} if(!${eval.nullTerm}) { - $nullTerm = false; - $primitiveTerm = ${eval.primitiveTerm}; + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} } } """ - }.mkString("\n") + } - case e @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition, ctx) - val trueEval = expressionEvaluator(trueValue, ctx) - val falseEval = expressionEvaluator(falseValue, ctx) + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) - s""" - boolean $nullTerm = false; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - ${condEval.code} + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ${trueEval.code} - $nullTerm = ${trueEval.nullTerm}; - $primitiveTerm = ${trueEval.primitiveTerm}; + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} } else { - ${falseEval.code} - $nullTerm = ${falseEval.nullTerm}; - $primitiveTerm = ${falseEval.primitiveTerm}; + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} } - """ + """.children case NewSet(elementType) => - s""" - boolean 
$nullTerm = false; - ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); - """ + q""" + val $nullTerm = false + val $primitiveTerm = new ${hashSetForType(elementType)}() + """.children case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item, ctx) - val setEval = expressionEvaluator(set, ctx) + val itemEval = expressionEvaluator(item) + val setEval = expressionEvaluator(set) val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + itemEval.code ++ setEval.code ++ + q""" + if (!${itemEval.nullTerm}) { + ${setEval.primitiveTerm} + .asInstanceOf[${hashSetForType(elementType)}] + .add(${itemEval.primitiveTerm}) } - boolean $nullTerm = false; - ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; - """ + + val $nullTerm = false + val $primitiveTerm = ${setEval.primitiveTerm} + """.children case CombineSets(left, right) => - val leftEval = expressionEvaluator(left, ctx) - val rightEval = expressionEvaluator(right, ctx) + val leftEval = expressionEvaluator(left) + val rightEval = expressionEvaluator(right) val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - leftEval.code + rightEval.code + - s""" - boolean $nullTerm = false; - ${htype} $primitiveTerm = - (${htype})${leftEval.primitiveTerm}; - $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); - """ + leftEval.code ++ rightEval.code ++ + q""" + val $nullTerm = false + var $primitiveTerm: ${hashSetForType(elementType)} = null + + { + val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val iterator = rightSet.iterator + while (iterator.hasNext) { + leftSet.add(iterator.next()) + } + $primitiveTerm = leftSet + } + """.children - case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) + case MaxOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} } else { if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; + $primitiveTerm = ${eval1.primitiveTerm} } else { - $primitiveTerm = ${eval2.primitiveTerm}; + $primitiveTerm = ${eval2.primitiveTerm} } } - """ + """.children - case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) - eval1.code + eval2.code + - s""" - boolean $nullTerm = 
false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} } else { if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; + $primitiveTerm = ${eval1.primitiveTerm} } else { - $primitiveTerm = ${eval2.primitiveTerm}; + $primitiveTerm = ${eval2.primitiveTerm} } } - """ + """.children case UnscaledValue(child) => - val childEval = expressionEvaluator(child, ctx) - - childEval.code + - s""" - boolean $nullTerm = ${childEval.nullTerm}; - long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); - """ + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: Long = if (!$nullTerm) { + ${childEval.primitiveTerm}.toUnscaledLong + } else { + ${defaultPrimitive(LongType)} + } + """.children case MakeDecimal(child, precision, scale) => - val eval = expressionEvaluator(child, ctx) + val childEval = expressionEvaluator(child) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: org.apache.spark.sql.types.Decimal = + ${defaultPrimitive(DecimalType())} if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal(); - $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); - $nullTerm = $primitiveTerm == null; + $primitiveTerm = new org.apache.spark.sql.types.Decimal() + $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) + $nullTerm = $primitiveTerm == null } - """ + """.children } // If there was no match in the partial function above, we fall back on calling the interpreted // expression evaluator. - val code: String = + val code: Seq[Tree] = primitiveEvaluation.lift.apply(e).getOrElse { - logError(s"No rules to generate $e") - ctx.references += e - s""" - /* expression: ${e} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean $nullTerm = $objectTerm == null; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; - """ + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + // Only inject debugging code if debugging is turned on. 
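// A minimal, standalone model of the EvaluatedExpression contract used above:
// every expression yields a code fragment plus a null flag and a primitive
// result term, and binary operators compose fragments by OR-ing the null flags
// and guarding the computation. Fragment, evalLiteral and evalAdd are invented
// names for illustration only; the real code on this side of the diff emits
// Scala ASTs (Seq[Tree]) via quasiquotes rather than strings, and falls back
// to interpreted `eval(i)` when no generation rule matches.
object EvaluatedExpressionSketch {
  final case class Fragment(code: String, nullTerm: String, primitiveTerm: String)

  private var id = 0
  private def fresh(prefix: String): String = { id += 1; s"$prefix$id" }

  def evalLiteral(value: Option[Int]): Fragment = {
    val n = fresh("nullTerm"); val p = fresh("primitiveTerm")
    Fragment(s"val $n = ${value.isEmpty}\nval $p: Int = ${value.getOrElse(-1)}\n", n, p)
  }

  /** Compose two fragments the way the generated Add / Subtract / Multiply cases do. */
  def evalAdd(left: Fragment, right: Fragment): Fragment = {
    val n = fresh("nullTerm"); val p = fresh("primitiveTerm")
    val code = left.code + right.code +
      s"""val $n = ${left.nullTerm} || ${right.nullTerm}
         |val $p: Int = if ($n) -1 else ${left.primitiveTerm} + ${right.primitiveTerm}
         |""".stripMargin
    Fragment(code, n, p)
  }

  def main(args: Array[String]): Unit = {
    val f = evalAdd(evalLiteral(Some(1)), evalLiteral(None))
    // Prints the composed source; the sum's null flag is true because one input is null.
    println(f.code)
  }
}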
+ val debugCode = + if (debugLogging) { + val localLogger = log + val localLoggerTree = reify { localLogger } + q""" + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) + """ :: Nil + } else { + Nil } - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) } - protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { - case StringType => s"(${stringType})$inputRow.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" + case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } } protected def setColumn( - destinationRow: String, + destinationRow: TermName, dataType: DataType, ordinal: Int, - value: String): String = { + value: TermName) = { dataType match { - case StringType => s"$destinationRow.update($ordinal, $value)" + case StringType => q"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => - s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => s"$destinationRow.update($ordinal, $value)" + q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" } } - protected def accessorForType(dt: DataType) = dt match { - case IntegerType => "getInt" - case other => s"get${termForType(dt)}" - } - - protected def mutatorForType(dt: DataType) = dt match { - case IntegerType => "setInt" - case other => s"set${termForType(dt)}" - } + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") - protected def hashSetForType(dt: DataType): String = dt match { - case IntegerType => classOf[IntegerHashSet].getName - case LongType => classOf[LongHashSet].getName + protected def hashSetForType(dt: DataType) = dt match { + case IntegerType => typeOf[IntegerHashSet] + case LongType => typeOf[LongHashSet] case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "int" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - - protected def defaultPrimitive(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "-1" - case LongType => "-1" - case ByteType => "-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" - case dt: DecimalType => "null" - case StringType => "null" - case _ => "null" - } - - protected def termForType(dt: DataType): String = dt match { - case IntegerType => "Integer" + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" case LongType 
=> "Long" case ShortType => "Short" case ByteType => "Byte" case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "Integer" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" + case StringType => "org.apache.spark.sql.types.UTF8String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(-1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" + case IntegerType => ru.Literal(Constant(-1)) + case DateType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: AtomicType => n.tag + case _ => typeTag[Any] } /** * List of data types that have special accessors and setters in [[Row]]. */ protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 638b53fe0fe2f..840260703ab74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ -// MutableProjection is not accessible in Java -abstract class BaseMutableProjection extends MutableProjection {} - /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new * input [[Row]] for a fixed set of [[Expression Expressions]]. 
*/ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -35,61 +36,41 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = expressionEvaluator(e, ctx) - evaluationCode.code + - s""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i); - else - ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; - """ - }.mkString("\n") - val code = s""" - import org.apache.spark.sql.Row; - - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - - private $exprType[] expressions = null; - private $mutableRowType mutableRow = null; - - public SpecificProjection($exprType[] expr) { - expressions = expr; - mutableRow = new $genericMutableRowType(${expressions.size}); - } - - public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { - mutableRow = row; - return this; - } + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } - /* Provide immutable access to the last projected row. */ - public Row currentValue() { - return mutableRow; - } + val code = + q""" + () => { new $mutableProjectionType { - public Object apply(Object _i) { - Row i = (Row) _i; - $projectionCode + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) - return mutableRow; - } - } - """ + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + /* Provide immutable access to the last projected row. 
*/ + def currentValue: $rowType = mutableRow - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - () => { - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] - } + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0ff840dab393c..b129c0d898bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,29 +18,18 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging -import org.apache.spark.annotation.Private -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, NumericType} - -/** - * Inherits some default implementation for Java from `Ordering[Row]` - */ -@Private -class BaseOrdering extends Ordering[Row] { - def compare(a: Row, b: Row): Int = { - throw new UnsupportedOperationException - } -} +import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { + import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = @@ -49,90 +38,73 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { val a = newTermName("a") val b = newTermName("b") - val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child, ctx) - val evalB = expressionEvaluator(order.child, ctx) - val asc = order.direction == Ascending + val evalA = expressionEvaluator(order.child) + val evalB = expressionEvaluator(order.child) + val compare = order.child.dataType match { case BinaryType => - s""" - { - byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; - byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; - int j = 0; - while (j < x.length && j < y.length) { - if (x[j] != y[j]) return x[j] - y[j]; - j = j + 1; - } - int d = x.length - y.length; - if (d != 0) { - return d; - } - }""" + q""" + val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} + val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} + var i = 0 + while (i < x.length && i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + i = i+1 + } + return x.length - y.length + """ case _: NumericType => - s""" - if (${evalA.primitiveTerm} 
!= ${evalB.primitiveTerm}) { - if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { - return ${if (asc) "1" else "-1"}; - } else { - return ${if (asc) "-1" else "1"}; - } - }""" - case _ => - s""" - int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - }""" - } - - s""" - i = $a; - ${evalA.code} - i = $b; - ${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) "1" else "-1"}; + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" } else { - $compare + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" } - """ - }.mkString("\n") - - val code = s""" - import org.apache.spark.sql.Row; - - public SpecificOrdering generate($exprType[] expr) { - return new SpecificOrdering(expr); } - class SpecificOrdering extends ${typeOf[BaseOrdering]} { - - private $exprType[] expressions = null; - - public SpecificOrdering($exprType[] expr) { - expressions = expr; + q""" + i = $a + ..${evalA.code} + i = $b + ..${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) q"-1" else q"1"} + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) q"1" else q"-1"} + } else { + $compare } + """ + } - @Override - public int compare(Row a, Row b) { - Row i = null; // Holds current row being evaluated. - $comparisons - return 0; + val q"class $orderingName extends $orderingType { ..$body }" = reify { + class SpecificOrdering extends Ordering[Row] { + val o = ordering + } + }.tree.children.head + + val code = q""" + class $orderingName extends $orderingType { + ..$body + def compare(a: $rowType, b: $rowType): Int = { + var i: $rowType = null // Holds current row being evaluated. + ..$comparisons + return 0 } - }""" - + } + new $orderingName() + """ logDebug(s"Generated Ordering: $code") - - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] + toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index fb18769f00da3..40e163024360e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -19,17 +19,12 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ -/** - * Interface for generated predicate - */ -abstract class Predicate { - def eval(r: Row): Boolean -} - /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. 
*/ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) @@ -37,34 +32,17 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { BindReferences.bindReference(in, inputSchema) protected def create(predicate: Expression): ((Row) => Boolean) = { - val ctx = newCodeGenContext() - val eval = expressionEvaluator(predicate, ctx) - val code = s""" - import org.apache.spark.sql.Row; + val cEval = expressionEvaluator(predicate) - public SpecificPredicate generate($exprType[] expr) { - return new SpecificPredicate(expr); - } - - class SpecificPredicate extends ${classOf[Predicate].getName} { - private final $exprType[] expressions; - public SpecificPredicate($exprType[] expr) { - expressions = expr; - } - - @Override - public boolean eval(Row i) { - ${eval.code} - return !${eval.nullTerm} && ${eval.primitiveTerm}; + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} } - }""" - - logDebug(s"Generated predicate '$predicate':\n$code") + """ - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] - (r: Row) => p.eval(r) + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d5be1fc12e0f0..31c63a79ebc8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,14 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProject extends Projection {} /** * Generates bytecode that produces a new [[Row]] object based on a fixed set of input @@ -32,6 +27,7 @@ abstract class BaseProject extends Projection {} * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { + import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -42,183 +38,201 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" - }.mkString("\n ") + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... 
+ val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + override def setNullAt(i: Int) = { nullBits(i) = true } + override def isNullAt(i: Int) = nullBits(i) + """.children - val initColumns = expressions.zipWithIndex.map { + val tupleElements = expressions.zipWithIndex.flatMap { case (e, i) => - val eval = expressionEvaluator(e, ctx) - s""" + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ { - // column$i - ${eval.code} - nullBits[$i] = ${eval.nullTerm}; - if(!${eval.nullTerm}) { - c$i = ${eval.primitiveTerm}; + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else { + nullBits($iLit) = false + $elementName = ${evaluatedExpression.primitiveTerm} } } - """ - }.mkString("\n") + """.children : Seq[Tree] + } - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n ") + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" - }.mkString("\n ") + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + nullBits(i) = false + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + } + return + }""" + } + q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } val specificAccessorFunctions = nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => - s"case $i: return c$i;" - case _ => "" - }.mkString("\n ") - if (cases.count(_ != '\n') > 0) { - s""" - @Override - public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { - if (isNullAt(i)) { - return ${defaultPrimitive(dataType)}; - } - switch (i) { - $cases - } - return ${defaultPrimitive(dataType)}; - }""" - } else { - "" + val ifStatements = expressions.zipWithIndex.flatMap { + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? 
+ q"if(i == $i) return $elementName" :: Nil + case _ => Nil } - }.mkString("\n") + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + } val specificMutatorFunctions = nativeTypes.map { dataType => - val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => - s"case $i: { c$i = value; return; }" - case _ => "" - }.mkString("\n") - if (cases.count(_ != '\n') > 0) { - s""" - @Override - public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { - nullBits[i] = false; - switch (i) { - $cases - } - }""" - } else { - "" + val ifStatements = expressions.zipWithIndex.flatMap { + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil + case _ => Nil + } + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" } - }.mkString("\n") + } val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = newTermName(s"c$i") + val elementName = newTermName(s"c$i") val nonNull = e.dataType match { - case BooleanType => s"$col ? 0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" case DoubleType => - s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" - case _ => s"$col.hashCode()" + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" } - s"isNullAt($i) ? 
0 : ($nonNull)" + q"if (isNullAt($i)) 0 else $nonNull" } - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) { - return false; - } + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } """ - }.mkString("\n") - val code = s""" - import org.apache.spark.sql.Row; - - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" } - class SpecificProjection extends ${typeOf[BaseProject]} { - private $exprType[] expressions = null; - - public SpecificProjection($exprType[] expr) { - expressions = expr; - } + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ - @Override - public Object apply(Object r) { - return new SpecificRow(expressions, (Row) r); - } + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" } - final class SpecificRow extends ${typeOf[BaseMutableRow]} { - - $columns - - public SpecificRow($exprType[] expressions, Row i) { - $initColumns - } - - public int size() { return ${expressions.length};} - private boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; + val copyFunction = + q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" + + val toSeqFunction = + q"override def toSeq: Seq[Any] = Seq(..$allColumns)" + + val classBody = + nullFunctions ++ ( + lengthDef +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + copyFunction +: + toSeqFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody } - @Override - public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; - if (row.length() != size()) return false; - $columnChecks - return true; - } - return super.equals(other); - } - } + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } """ - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") - - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} 
code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5800..528e38a50a740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,6 +27,12 @@ import org.apache.spark.util.Utils */ package object codegen { + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 58273b166fe91..807021d50e8e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -202,8 +202,9 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { sys.error(s"BinaryComparisons must override either eval or evalInternal") } -private[sql] object BinaryComparison { - def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) +object BinaryComparison { + def unapply(b: BinaryComparison): Option[(Expression, Expression)] = + Some((b.left, b.right)) } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c16f08d389955..5c6379b8d44b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,8 +36,6 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. 
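For orientation on what GenerateProjection above now produces: the quasiquote tree declares a SpecificRow class with one typed var per output column (c0, c1, ...), a nullBits array, if-chains keyed on the ordinal, and a hashCode that folds column hashes with result = 37 * result + h. Written out by hand for a two-column (Int, String) schema, the generated shape looks roughly like the standalone sketch below; it is a plain class, not Spark's MutableRow, and SpecificRowSketch is an illustrative name.

    // Hand-written analogue of the generated row for schema (c0: Int, c1: String).
    class SpecificRowSketch {
      private[this] val nullBits = new Array[Boolean](2)
      var c0: Int = 0
      var c1: String = null

      def setNullAt(i: Int): Unit = nullBits(i) = true
      def isNullAt(i: Int): Boolean = nullBits(i)

      // The generated apply is a chain of `if (i == ordinal)` tests ending in an error.
      def apply(i: Int): Any =
        if (i == 0) { if (isNullAt(0)) null else c0 }
        else if (i == 1) { if (isNullAt(1)) null else c1 }
        else sys.error("Invalid ordinal:" + i)

      // Same folding scheme as the generated hashCode: start at 37, then 37 * result + columnHash.
      override def hashCode(): Int = {
        var result = 37
        result = 37 * result + (if (isNullAt(0)) 0 else c0)
        result = 37 * result + (if (isNullAt(1)) 0 else c1.hashCode)
        result
      }
    }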
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Distinct", FixedPoint(100), - ReplaceDistinctWithAggregate) :: Batch("Operator Reordering", FixedPoint(100), UnionPushdown, CombineFilters, @@ -266,7 +264,7 @@ object NullPropagation extends Rule[LogicalPlan] { if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { - newChildren.head + newChildren(0) } else { Coalesce(newChildren) } @@ -280,18 +278,21 @@ object NullPropagation extends Rule[LogicalPlan] { case e: MinOf => e // Put exceptional cases above if any - case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) - + case e: BinaryArithmetic => e.children match { + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) + case _ => e + } + case e: BinaryComparison => e.children match { + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) + case _ => e + } case e: StringRegexExpression => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } - case e: StringComparison => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) @@ -695,15 +696,3 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } } - -/** - * Replaces logical [[Distinct]] operator with an [[Aggregate]] operator. - * {{{ - * SELECT DISTINCT f1, f2 FROM t ==> SELECT f1, f2 FROM t GROUP BY f1, f2 - * }}} - */ -object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Distinct(child) => Aggregate(child.output, child.output, child) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e77e5c27b687a..33a9e55a47dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -339,9 +339,6 @@ case class Sample( override def output: Seq[Attribute] = child.output } -/** - * Returns a new logical plan that dedups input rows. 
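The NullPropagation hunks just above swap the BinaryArithmetic/BinaryComparison extractor patterns for matches on each expression's children list. The standalone sketch below shows that matching style on a toy expression tree; Expr, Lit and Add are made-up stand-ins for illustration, not Catalyst classes.

    // Toy expression tree: propagate nulls by pattern matching on `children`.
    sealed trait Expr { def children: Seq[Expr] }
    case class Lit(value: Any) extends Expr { val children = Nil }
    case class Add(left: Expr, right: Expr) extends Expr { val children = Seq(left, right) }

    def propagateNulls(e: Expr): Expr = e match {
      case b: Add => b.children match {
        case Lit(null) :: _ :: Nil => Lit(null)   // null + x  ==>  null
        case _ :: Lit(null) :: Nil => Lit(null)   // x + null  ==>  null
        case _                     => b
      }
      case other => other
    }

    // propagateNulls(Add(Lit(null), Lit(1)))  ==>  Lit(null)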
- */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5df528770ca6e..b6927485f42bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -344,7 +344,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("abdef" cast TimestampType, null) checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - checkEvaluation(Literal(1) cast LongType, 1.toLong) + checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) @@ -363,16 +363,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), - 5.toLong) + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), - 0.toShort) + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Literal(true) cast StringType, "true") @@ -512,9 +509,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15.toShort) + checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15.toLong) + checkEvaluation(Cast(ts, LongType), 15) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8cfd853afa35f..d7c437095e395 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -32,12 +32,11 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { 
case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + val evaluated = GenerateProjection.expressionEvaluator(expression) fail( s""" |Code generation of $expression failed: - |${evaluated.code} + |${evaluated.code.mkString("\n")} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 9ab1f7d7ad0db..a40324b008e16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -28,8 +28,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) @@ -38,7 +37,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { fail( s""" |Code generation of $expression failed: - |${evaluated.code} + |${evaluated.code.mkString("\n")} |$e """.stripMargin) } @@ -50,7 +49,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { s""" |Mismatched hashCodes for values: $actual, $expectedRow |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code} + |${evaluated.code.mkString("\n")} """.stripMargin) } if (actual != expectedRow) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala deleted file mode 100644 index df29a62ff0e15..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor - -class ReplaceDistinctWithAggregateSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil - } - - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ed75475a87067..8210c552603ea 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml @@ -61,11 +61,11 @@ test - org.apache.parquet + com.twitter parquet-column - org.apache.parquet + com.twitter parquet-hadoop diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4a224153e1a37..034d887901975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1082,22 +1082,6 @@ class DataFrame private[sql]( } } - /** - * Returns a new [[DataFrame]] with a column dropped. - * This version of drop accepts a Column rather than a name. - * This is a no-op if the DataFrame doesn't have a column - * with an equivalent expression. - * @group dfops - * @since 1.4.1 - */ - def drop(col: Column): DataFrame = { - val attrs = this.logicalPlan.output - val colsAfterDrop = attrs.filter { attr => - attr != col.expr - }.map(attr => Column(attr)) - select(colsAfterDrop : _*) - } - /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * This is an alias for `distinct`. @@ -1311,7 +1295,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct: DataFrame = dropDuplicates() + override def distinct: DataFrame = Distinct(logicalPlan) /** * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ddb54025baa24..91e6385dec81b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -705,18 +705,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from 0 to `end` (exclusive) with step value 1. - * - * @since 1.4.1 - * @group dataframe - */ - @Experimental - def range(end: Long): DataFrame = range(0, end) - - /** - * :: Experimental :: - * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with step value 1. + * in an range from `start` to `end`(exclusive) with step value 1. 
* * @since 1.4.0 * @group dataframe @@ -731,7 +720,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value, with partition number + * in an range from `start` to `end`(exclusive) with an step value, with partition number * specified. * * @since 1.4.0 @@ -916,11 +905,6 @@ class SQLContext(@transient val sparkContext: SparkContext) tlSession.remove() } - protected[sql] def setSession(session: SQLSession): Unit = { - detachSession() - tlSession.set(session) - } - protected[sql] class SQLSession { // Note that this is a lazy val so we can override the default value in subclasses. protected[sql] lazy val conf: SQLConf = new SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 43b62f0e822f8..604f3124e23ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -139,19 +139,4 @@ private[r] object SQLUtils { case "ignore" => SaveMode.Ignore } } - - def loadDF( - sqlContext: SQLContext, - source: String, - options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).options(options).load() - } - - def loadDF( - sqlContext: SQLContext, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).schema(schema).options(options).load() - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a1331a39151a..d0a1ad00560d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -284,8 +284,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => - throw new IllegalStateException( - "logical distinct operator should have been replaced by aggregate in the optimizer") + execution.Distinct(partial = false, + execution.Distinct(partial = true, planLater(child))) :: Nil case logical.Repartition(numPartitions, shuffle, child) => execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fb42072f9d5a7..a30ade86441ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -230,6 +230,37 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } +/** + * :: DeveloperApi :: + * Computes the set of distinct input rows using a HashSet. + * @param partial when true the distinct operation is performed partially, per partition, without + * shuffling the data. + * @param child the input query plan. 
+ */ +@DeveloperApi +case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def requiredChildDistribution: Seq[Distribution] = + if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil + + protected override def doExecute(): RDD[Row] = { + child.execute().mapPartitions { iter => + val hashSet = new scala.collection.mutable.HashSet[Row]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + if (!hashSet.contains(currentRow)) { + hashSet.add(currentRow.copy()) + } + } + + hashSet.iterator + } + } +} + /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index 62c4e92ebec68..f5ce2718bec4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import org.apache.parquet.Log -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 85c2ce740fe52..caa9f045537d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,9 +23,9 @@ import java.util.{TimeZone, Calendar} import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import jodd.datetime.JDateTime -import org.apache.parquet.column.Dictionary -import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import org.apache.parquet.schema.MessageType +import parquet.column.Dictionary +import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} +import parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 88ae88e9684c8..f0f4e7d147e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,11 +21,11 @@ import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.compat.FilterCompat -import org.apache.parquet.filter2.compat.FilterCompat._ -import org.apache.parquet.filter2.predicate.FilterApi._ -import 
org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.io.api.Binary +import parquet.filter2.compat.FilterCompat +import parquet.filter2.compat.FilterCompat._ +import parquet.filter2.predicate.FilterApi._ +import parquet.filter2.predicate.{FilterApi, FilterPredicate} +import parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 09088ee91106c..fcb9513ab66f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -24,9 +24,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction import org.apache.spark.sql.types.{StructType, DataType} -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.schema.MessageType +import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} +import parquet.hadoop.metadata.CompressionCodecName +import parquet.schema.MessageType import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} @@ -107,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[org.apache.parquet.Log].getName) + Class.forName(classOf[parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. 
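Stepping back to basicOperators.scala above: the restored execution.Distinct deduplicates each partition by copying unseen rows into a mutable HashSet and returning the set's iterator, and SparkStrategies plans it twice, a partial pass per partition before the shuffle and a final pass after a ClusteredDistribution exchange. A standalone sketch of the per-partition pass, using plain Scala collections rather than Spark's Row and RDD types (dedupPartition is an illustrative name):

    import scala.collection.mutable

    // One partition's worth of the restored execution.Distinct:
    // keep the first occurrence of each element, then hand back an iterator.
    def dedupPartition[T](iter: Iterator[T]): Iterator[T] = {
      val seen = mutable.HashSet.empty[T]
      while (iter.hasNext) {
        val row = iter.next()
        if (!seen.contains(row)) {
          seen += row   // the real operator stores row.copy(), since Spark reuses row objects
        }
      }
      seen.iterator
    }

    // dedupPartition(Iterator(1, 2, 2, 3, 1)).toSet == Set(1, 2, 3)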
@@ -127,7 +127,7 @@ private[sql] object ParquetRelation { type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow // The compression type - type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName + type CompressionType = parquet.hadoop.metadata.CompressionCodecName // The parquet compression short names val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1e694f2feabee..cb7ae246d0d75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,13 +33,13 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} -import org.apache.parquet.hadoop.metadata.GlobalMetaData -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.io.ParquetDecodingException -import org.apache.parquet.schema.MessageType +import parquet.hadoop._ +import parquet.hadoop.api.ReadSupport.ReadContext +import parquet.hadoop.api.{InitContext, ReadSupport} +import parquet.hadoop.metadata.GlobalMetaData +import parquet.hadoop.util.ContextUtil +import parquet.io.ParquetDecodingException +import parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -78,7 +78,7 @@ private[sql] case class ParquetTableScan( }.toArray protected override def doExecute(): RDD[Row] = { - import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -136,7 +136,7 @@ private[sql] case class ParquetTableScan( baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = - split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] + split.asInstanceOf[parquet.hadoop.ParquetInputSplit] .getPath .toString .split("/") @@ -378,7 +378,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends org.apache.parquet.hadoop.ParquetOutputFormat[Row] { + extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,7 +431,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. 
*/ private[parquet] class FilteringParquetRowInputFormat - extends org.apache.parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends parquet.hadoop.ParquetInputFormat[Row] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] @@ -439,7 +439,7 @@ private[parquet] class FilteringParquetRowInputFormat inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter + import parquet.filter2.compat.FilterCompat.NoOpFilter val readSupport: ReadSupport[Row] = new RowReadSupport() @@ -501,7 +501,7 @@ private[parquet] class FilteringParquetRowInputFormat globalMetaData = new GlobalMetaData(globalMetaData.getSchema, mergedMetadata, globalMetaData.getCreatedBy) - val readContext = ParquetInputFormat.getReadSupportInstance(configuration).init( + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData, globalMetaData.getSchema)) @@ -531,8 +531,8 @@ private[parquet] class FilteringParquetRowInputFormat minSplitSize: JLong, readContext: ReadContext): JList[ParquetInputSplit] = { - import org.apache.parquet.filter2.compat.FilterCompat.Filter - import org.apache.parquet.filter2.compat.RowGroupFilter + import parquet.filter2.compat.FilterCompat.Filter + import parquet.filter2.compat.RowGroupFilter import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache @@ -547,7 +547,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("org.apache.parquet.hadoop.ClientSideMetadataSplitStrategy") + Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) @@ -612,7 +612,7 @@ private[parquet] class FilteringParquetRowInputFormat // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = - Class.forName("org.apache.parquet.hadoop.TaskSideMetadataSplitStrategy") + Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( sys.error( s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 89db408b1c382..70a220cc43ab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import org.apache.parquet.column.ParquetProperties -import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} -import org.apache.parquet.io.api._ -import org.apache.parquet.schema.MessageType +import parquet.column.ParquetProperties +import parquet.hadoop.ParquetOutputFormat +import parquet.hadoop.api.ReadSupport.ReadContext +import parquet.hadoop.api.{ReadSupport, WriteSupport} +import parquet.io.api._ +import parquet.schema.MessageType 
import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index ba2a35b74ef82..6698b19c7477d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,25 +19,26 @@ package org.apache.spark.sql.parquet import java.io.IOException -import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.Type.Repetition +import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} -import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkException} +// Implicits +import scala.collection.JavaConversions._ /** A class representing Parquet info fields we care about, for passing back to Parquet */ private[parquet] case class ParquetTypeInfo( @@ -72,12 +73,13 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
- throw new AnalysisException("Potential loss of precision: cannot convert INT96") + sys.error("Potential loss of precision: cannot convert INT96") case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => // TODO: for now, our reader only supports decimals that fit in a Long DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") } } @@ -369,7 +371,7 @@ private[parquet] object ParquetTypesConverter extends Logging { parquetKeyType, parquetValueType) } - case _ => throw new AnalysisException(s"Unsupported datatype $ctype") + case _ => sys.error(s"Unsupported datatype $ctype") } } } @@ -401,7 +403,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => throw new AnalysisException(s"Can convert $string to row") + case other => sys.error(s"Can convert $string to row") } } @@ -409,8 +411,8 @@ private[parquet] object ParquetTypesConverter extends Logging { // ,;{}()\n\t= and space character are special characters in Parquet schema schema.map(_.name).foreach { name => if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw new AnalysisException( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". |Please use alias to rename it. """.stripMargin.split("\n").mkString(" ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 5dda440240e60..824ae36968c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -29,17 +29,16 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.util.ContextUtil +import parquet.filter2.predicate.FilterApi +import parquet.hadoop._ +import parquet.hadoop.metadata.CompressionCodecName +import parquet.hadoop.util.ContextUtil import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLConf, SQLContext} @@ -84,7 +83,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext case partFilePattern(id) => id.toInt case name if name.startsWith("_") => 0 case name if name.startsWith(".") => 0 - case name => throw new AnalysisException( + case name => sys.error( s"Trying to write Parquet files to directory $outputPath, " + s"but found items with illegal name '$name'.") }.reduceOption(_ max 
_).getOrElse(0) @@ -381,12 +380,11 @@ private[sql] class ParquetRelation2( // time-consuming. if (dataSchema == null) { dataSchema = { - val dataSchema0 = maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(throw new AnalysisException( - s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + - paths.mkString("\n\t"))) + val dataSchema0 = + maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) // If this Parquet relation is converted from a Hive Metastore table, must reconcile case // case insensitivity issue and possible schema mismatch (probably caused by schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala index 4d5ed211ad0c0..70bcca7526aae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.parquet.timestamp import java.nio.{ByteBuffer, ByteOrder} -import org.apache.parquet.Preconditions -import org.apache.parquet.io.api.{Binary, RecordConsumer} +import parquet.Preconditions +import parquet.io.api.{Binary, RecordConsumer} private[parquet] class NanoTime extends Serializable { private var julianDay = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index e9932c09107db..71f016b1f14de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.parquet.hadoop.util.ContextUtil +import parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 12fb128149d32..28e90b9520b2c 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. 
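The log4j test configuration that follows has to track the same package move: logger names are just class-name prefixes, so once the Parquet classes live under the bundled parquet.* namespace again, the entries must target parquet.hadoop.* instead of org.apache.parquet.hadoop.*. A programmatic equivalent of those entries, as a sketch assuming log4j 1.x on the classpath (the backend Spark 1.4 logs through):

    import org.apache.log4j.{Level, Logger}

    // Same effect as the properties entries below: silence the bundled Parquet
    // record reader and keep the rest of parquet.hadoop at WARN.
    Logger.getLogger("parquet.hadoop.ParquetRecordReader").setLevel(Level.OFF)
    Logger.getLogger("parquet.hadoop.ParquetRecordReader").setAdditivity(false)
    Logger.getLogger("parquet.hadoop").setLevel(Level.WARN)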
-log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false -log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF -log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false -log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF +log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.org.apache.parquet.hadoop=WARN +log4j.logger.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72e60d9aa75cb..0772e5e187425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,6 +25,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -32,12 +34,8 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -47,47 +45,47 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != ctx.cacheManager.lookupCachedData(testData)) + assert(None != cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == 
ctx.cacheManager.lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -95,103 +93,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + table("testData").cache() + assertCached(table("testData")) + table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + table("testData").cache() + table("testData").count() + table("testData").unpersist(blocking = true) + assertCached(table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(table("testData")) + assert(table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + uncacheTable("testData") + assert(!isCached("testData")) + assert(table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + cacheTable("testData") + assertCached(table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + table("testData").queryExecution.withCachedData.collect { case r @ 
InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + uncacheTable("testData") } test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + cacheTable("testData") + checkAnswer(table("testData"), testData.collect().toSeq) + assertCached(table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + uncacheTable("testData") + checkAnswer(table("testData"), testData.collect().toSeq) + assertCached(table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(table("testData")) val rddId = rddIdOf("testData") assert( @@ -199,7 +197,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -208,14 +206,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -223,14 +221,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -238,7 +236,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + assertCached(table("testData")) val rddId = 
rddIdOf("testData") assert( @@ -250,7 +248,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -258,7 +256,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -267,38 +265,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + table("t1") + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(isCached("t1")) + assert(isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + assert(!isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + cacheTable("t1") + cacheTable("t2") + clearCache() + assert(cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + cacheTable("t1") + cacheTable("t2") sql("Clear CACHE") - assert(ctx.cacheManager.isEmpty) + assert(cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -307,8 +305,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.cacheTable("t1") - ctx.cacheTable("t2") + cacheTable("t1") + cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -319,8 +317,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + uncacheTable("t1") + uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f5484f1368d1..bfba379d9a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,14 +21,13 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ 
+import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -214,7 +213,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -275,7 +274,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = TestSQLContext.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -288,7 +287,7 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -414,7 +413,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -424,7 +423,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 790b405c72697..232f05c00918f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -68,12 +67,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "false") + TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "true") + TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 
53c2befb73702..b1e0faa310b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -26,9 +27,6 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a4943..2d2367d6e7292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} +import org.apache.spark.sql.test.TestSQLContext.implicits._ + - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest { test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +37,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sc.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sc.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 051d13e9a544f..787f3f175fea2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ -class DataFrameJoinSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest { test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -49,8 +49,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") - .collect().toSeq) + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) } test("join - using aliases 
after self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 495701d4f616c..41b4f02e6a294 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest { def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..438f479459dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends SparkFunSuite { - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ - - private def toLetter(i: Int): String = (i + 97).toChar.toString + val sqlCtx = TestSQLContext + def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index bb8621abe64ad..a4fd1058afce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,19 +21,17 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} +import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis // Eager analysis. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -47,11 +45,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") testData.select('nonExistentName) // Set the flag back to original value before this test. 
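// A minimal sketch, assuming the helpers used in this hunk (`TestSQLContext.conf`,
// `TestSQLContext.setConf`) and the 1.4-era string-valued SQLConf keys; the object and
// method names are illustrative, not part of the patch. It wraps the save/flip/restore
// of the eager-analysis flag in try/finally, the form the sortMergeJoin tests elsewhere
// in this patch use, so the old value is restored even if an assertion in `body` fails.
package org.apache.spark.sql

import org.apache.spark.sql.test.TestSQLContext

object EagerAnalysisFlagSketch {
  def withEagerAnalysis(enabled: Boolean)(body: => Unit): Unit = {
    // Remember the current flag, flip it for the duration of `body`, then restore it.
    val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis
    TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, enabled.toString)
    try {
      body
    } finally {
      TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
    }
  }
}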
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("dataframe toString") { @@ -69,12 +67,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() + TestSQLContext.debug() val badPlan = testData.select('badColumn) @@ -83,7 +81,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("access complex data") { @@ -99,8 +97,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(TestSQLContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -313,7 +311,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -336,51 +334,6 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name) === Seq("key", "value")) } - test("drop column using drop with column reference") { - val col = testData("key") - val df = testData.drop(col) - checkAnswer( - df, - testData.collect().map(x => Row(x.getString(1))).toSeq) - assert(df.schema.map(_.name) === Seq("value")) - } - - test("drop unknown column (no-op) with column reference") { - val col = Column("random") - val df = testData.drop(col) - checkAnswer( - df, - testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key", "value")) - } - - test("drop unknown column with same name (no-op) with column reference") { - val col = Column("key") - val df = testData.drop(col) - checkAnswer( - df, - testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key", "value")) - } - - test("drop column after join with duplicate columns using column reference") { - val newSalary = salary.withColumnRenamed("personId", "id") - val col = newSalary("id") - // this join will result in duplicate "id" columns - val joinedDf = person.join(newSalary, - person("id") === newSalary("id"), "inner") - // remove only the "id" column that was associated with newSalary - val df = joinedDf.drop(col) - checkAnswer( - df, - joinedDf.collect().map { - case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => - Row(id, name, age, salary) - }.toSeq) - assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) - assert(df("id") == person("id")) - } - test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") @@ -394,7 +347,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = 
TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -489,22 +442,19 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = TestSQLContext.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") - try{ - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) - } + val originalValue = TestSQLContext.conf.codegenEnabled + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -515,14 +465,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -542,7 +492,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = TestSQLContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -590,49 +540,41 @@ class DataFrameSuite extends QueryTest { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = TestSQLContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 
2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) - - // only end provided as argument - val res10 = ctx.range(10).select("id") - assert(res10.count == 10) - assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - - val res11 = ctx.range(-1).select("id") - assert(res11.count == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ffd26c4f5a7c2..407c789657834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,28 +20,27 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. 
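// A minimal sketch of the pattern the hunk above switches JoinSuite to: the suite pulls
// the shared test context in through wildcard imports instead of holding a `ctx` field,
// so `sql`, `dropTempTable`, and the `toDF` implicits all resolve against TestSQLContext.
// The suite name, table name, and query below are illustrative, not part of the patch.
package org.apache.spark.sql

import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._

class ImportPatternSketchSuite extends QueryTest {
  test("bare helpers resolve to the shared test context") {
    // Build a small DataFrame via the implicits and register it as a temp table.
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("sketchData")
    // `sql` and `checkAnswer` need no context prefix with the imports above in scope.
    checkAnswer(sql("SELECT key FROM sketchData WHERE value = 'a'"), Row(1))
    dropTempTable("sketchData")
  }
}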
TestData - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery - test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -62,9 +61,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - ctx.cacheManager.clearCache() + cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -95,22 +94,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + cacheManager.clearCache() + sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -118,7 +117,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -127,17 +126,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val 
planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -242,7 +241,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -302,7 +301,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -311,7 +310,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -363,7 +362,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -372,7 +371,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -387,7 +386,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -402,7 +401,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -412,11 +411,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold + cacheManager.clearCache() + sql("CACHE TABLE testData") + val tmp = conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -424,7 +423,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -432,12 +431,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - ctx.sql("UNCACHE TABLE testData") + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") 
checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf7..3ce97c3fffdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,47 +19,49 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import org.apache.spark.sql.test.TestSQLContext.implicits._ - private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") + val df = + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -67,20 +69,19 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( - "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + dropTempTable("tables") } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 0a38af2b4c889..dd68965444f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,29 +17,36 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions._ - +import java.lang.{Double => JavaDouble} -private object MathExpressionsTestData { - case class DoubleData(a: java.lang.Double, b: java.lang.Double) - case class NullDoubles(a: java.lang.Double) +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() - - private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() - - private lazy val nullDoubles = - Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() - - private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -58,8 +65,7 @@ class MathExpressionsSuite extends QueryTest { ) } - private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = - { + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -83,7 +89,7 @@ class MathExpressionsSuite extends QueryTest { ) } - private def testTwoToOneMathFunction( + def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index d84b57af9c882..513ac915dcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -21,13 +21,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends SparkFunSuite { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) @@ -57,7 +56,7 @@ class RowSuite 
extends SparkFunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 76d0dd1744a41..3a5f071e2f7cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,64 +17,67 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test._ -class SQLConfSuite extends QueryTest { +/* Implicits */ +import TestSQLContext._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest { - private val testKey = "test.key.0" - private val testVal = "test.val.0" + val testKey = "test.key.0" + val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(ctx.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") === "true") + val newContext = new SQLContext(TestSQLContext.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") == "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) + conf.clear() + assert(getAllConfs.size === 0) - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + setConf(testKey, testVal) + assert(getConf(testKey) == testVal) + assert(getConf(testKey, testVal + "_") == testVal) + assert(getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
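// A minimal sketch, assuming `import TestSQLContext._` as in the hunk above; the key and
// value strings are illustrative. It shows the bare-helper set/get/clear round trip this
// suite exercises: `setConf`/`getConf` and the SQL `SET` command write to the same SQLConf.
package org.apache.spark.sql

import org.apache.spark.sql.test.TestSQLContext._

object SQLConfRoundTripSketch {
  def roundTrip(): Unit = {
    conf.clear()
    // Programmatic set/get against the shared test context.
    setConf("test.key.sketch", "test.val.sketch")
    assert(getConf("test.key.sketch") == "test.val.sketch")
    assert(getAllConfs.contains("test.key.sketch"))
    // The SQL SET command updates the same configuration.
    sql("SET test.key.sketch=updated")
    assert(getConf("test.key.sketch", "fallback") == "updated")
    conf.clear()
  }
}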
- assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(TestSQLContext.getConf(testKey) == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + conf.clear() + sql(s"set $testKey=$testVal") + assert(getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - ctx.sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + sql("set some.property=20") + assert(getConf("some.property", "0") == "20") + sql("set some.property = 40") + assert(getConf("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + sql(s"set $key=$vs") + assert(getConf(key, "0") == vs) - ctx.sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + sql(s"set $key=") + assert(getConf(key, "0") == "") - ctx.conf.clear() + conf.clear() } test("deprecated property") { - ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") + conf.clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a41..797d123b48668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,29 +20,31 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + private val testSqlContext = TestSQLContext + private val testSparkContext = TestSQLContext.sparkContext override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + SQLContext.setLastInstantiatedContext(testSqlContext) } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(testSparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + val sqlContext = new SQLContext(testSparkContext) + assert(SQLContext.getOrCreate(testSparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - 
assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5babc4332cc77..63f7d314fb699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} +import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} + import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -34,9 +36,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - val sqlContext = org.apache.spark.sql.test.TestSQLContext + val sqlContext = TestSQLContext import sqlContext.implicits._ - import sqlContext.sql test("SPARK-6743: no columns from cache") { Seq( @@ -45,7 +46,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - sqlContext.cacheTable("cachedData") + cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -93,14 +94,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(TestSQLContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(TestSQLContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -117,8 +118,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( - """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -135,9 +135,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6201 IN type conversion") { - sqlContext.read.json( - sqlContext.sparkContext.parallelize( - Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + read.json( + sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -158,12 +157,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("aggregation with codegen") { - val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, 
"true") + val originalValue = conf.codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") // Prepare a table that we can group some rows. - sqlContext.table("testData") - .unionAll(sqlContext.table("testData")) - .unionAll(sqlContext.table("testData")) + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -185,79 +184,77 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df, expectedResults) } - try { - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - } finally { - sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) - } + // Just to group rows. 
+ testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") + setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } test("Add Parser of SQL COALESCE()") { @@ -448,43 +445,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("sorting") { - val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + val before = conf.externalSortEnabled + setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + val before = conf.externalSortEnabled + setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = sqlContext.conf.externalSortEnabled - val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - try{ - sortTest() - } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) - } + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.EXTERNAL_SORT, "false") + setConf(SQLConf.CODEGEN_ENABLED, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + 
setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = sqlContext.conf.externalSortEnabled - val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") - try { - sortTest() - } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) - } + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") + setConf(SQLConf.EXTERNAL_SORT, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } test("limit") { @@ -517,8 +508,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql( - "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -865,7 +855,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SET commands semantics using sql()") { - sqlContext.conf.clear() + conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -897,17 +887,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $nonexistentKey"), Row(s"$nonexistentKey=") ) - sqlContext.conf.clear() + conf.clear() } test("SET commands with illegal or inappropriate argument") { - sqlContext.conf.clear() + conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - sqlContext.conf.clear() + conf.clear() } test("apply schema") { @@ -925,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -955,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df2 = createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -980,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD3, schema2) + val df3 = createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1025,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) + val 
personWithMeta = createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1040,7 +1030,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3371 Renaming a function expression with group by gives error") { - sqlContext.udf.register("len", (s: String) => s.length) + TestSQLContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1221,9 +1211,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - sqlContext.read.json(data).registerTempTable("records") + read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1264,15 +1254,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) - .registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - sqlContext.dropTempTable("data") + dropTempTable("data") - sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - sqlContext.dropTempTable("data") + dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1291,10 +1279,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1303,23 +1291,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") + setConf(SQLConf.CASE_SENSITIVE, "false") val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = 
sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") + setConf(SQLConf.CASE_SENSITIVE, "true") } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1331,14 +1318,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ece3d6fdf2af5..d2ede39f0a5f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( stringField: String, @@ -74,15 +75,15 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends SparkFunSuite { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import org.apache.spark.sql.test.TestSQLContext.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) - Seq(data).toDF().registerTempTable("reflectData") + val rdd = sparkContext.parallelize(data :: Nil) + rdd.toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -90,26 +91,27 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - Seq(data).toDF().registerTempTable("reflectNullData") + val rdd = sparkContext.parallelize(data :: Nil) + rdd.toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === - Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data 
= OptionalReflectData(None, None, None, None, None, None, None) - Seq(data).toDF().registerTempTable("reflectOptionalData") + val rdd = sparkContext.parallelize(data :: Nil) + rdd.toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") + val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) + rdd.toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") - .collect().head(0).asInstanceOf[Array[Byte]] + val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -125,9 +127,10 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.toDF().registerTempTable("reflectComplexData") - Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b791..1e8cde606b67b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.TestSQLContext class SerializationSuite extends SparkFunSuite { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) + val sqlContext = new SQLContext(TestSQLContext.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 064c040d2b771..1a9ba66416b21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,41 +17,43 @@ package org.apache.spark.sql +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ +import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + udf.register("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + udf.register("random0", () => { Math.random()}) + assert(sql("SELECT 
random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + udf.register("strLenScala", (_: String).length + (_: Int)) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c10..dc2d43a197f40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.io.File + +import org.apache.spark.util.Utils + import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -24,11 +28,12 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -67,13 +72,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - private lazy val pointsRDD = Seq( + val points = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD = sparkContext.parallelize(points).toDF() + test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -91,10 +94,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index fa3b8144c086e..055453e688e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -29,12 +31,8 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Make sure the tables are loaded. TestData - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} - test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -42,16 +40,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -60,7 +58,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -72,7 +70,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -84,7 +82,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -96,7 +94,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("timestamps") + cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -108,7 +106,7 @@ class 
InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("withEmptyParts") + cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -157,7 +155,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -177,18 +175,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 6545c6b314a4c..cda1b0992e36f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,42 +21,40 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + val originalColumnBatchSize = conf.columnBatchSize + val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") 
// Enable in-memory table scan accumulators - ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } before { - ctx.cacheTable("pruningData") + cacheTable("pruningData") } after { - ctx.uncacheTable("pruningData") + uncacheTable("pruningData") } // Comparisons @@ -110,7 +108,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 7931854db27c1..e20c66cb2f1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,11 +21,13 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} -import org.h2.jdbc.JdbcSQLException -import org.scalatest.BeforeAndAfter - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ +import org.h2.jdbc.JdbcSQLException +import org.scalatest.BeforeAndAfter +import TestSQLContext._ +import TestSQLContext.implicits._ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -35,16 +37,12 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. 
We need these to test @@ -255,26 +253,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("Basic API") { - assert(ctx.read.jdbc( + assert(TestSQLContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(TestSQLContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -330,9 +328,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = TestSQLContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -340,8 +338,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -349,7 +348,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = TestSQLContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -396,8 +395,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) - assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) + val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter( + _.dataType != org.apache.spark.sql.types.StringType + ).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) assert(rows(0).get(1).isInstanceOf[String]) @@ -418,7 +419,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def 
getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { @@ -429,8 +430,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) - assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) == Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) == Some(StringType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index d949ef42267ec..2de8c1a6098e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { @@ -36,10 +37,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -57,14 +54,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -77,64 +74,66 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.close() } - private lazy val sc = ctx.sparkContext + val sc = TestSQLContext.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) - private lazy val schema2 = StructType( + val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + val arr1x2 = Array[Row](Row.apply("fred", 3)) + val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - private lazy val schema3 = StructType( + val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, 
"TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -143,15 +142,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + 
TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index d889c7be17ce7..f8d62f9e7e02b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,19 +23,21 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with TestJsonData { +class JsonSuite extends QueryTest { + import org.apache.spark.sql.json.TestJsonData._ - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.sql - import ctx.implicits._ + TestJsonData test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -212,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -231,7 +233,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -259,7 +261,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -358,7 +360,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -374,7 +376,7 @@ class JsonSuite 
extends QueryTest with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -448,7 +450,7 @@ class JsonSuite extends QueryTest with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -501,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -525,7 +527,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -553,7 +555,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -572,9 +574,8 @@ class JsonSuite extends QueryTest with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - ctx.sparkContext.parallelize(1 to 100) - .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -589,7 +590,7 @@ class JsonSuite extends QueryTest with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -601,7 +602,7 @@ class JsonSuite extends QueryTest with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -670,7 +671,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -687,7 +688,7 @@ class JsonSuite extends QueryTest with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ 
-708,7 +709,7 @@ class JsonSuite extends QueryTest with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -736,7 +737,7 @@ class JsonSuite extends QueryTest with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -762,7 +763,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -780,7 +781,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -803,7 +804,7 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -821,10 +822,10 @@ class JsonSuite extends QueryTest with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = ctx.read.json(corruptRecords) + val jsonDF = read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -874,11 +875,11 @@ class JsonSuite extends QueryTest with TestJsonData { Row("]") :: Nil ) - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -924,7 +925,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -947,7 +948,7 @@ class JsonSuite extends QueryTest with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -955,8 +956,8 @@ class JsonSuite extends QueryTest with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === 
"{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = read.json(primitiveFieldAndType) + val primTable = read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -968,8 +969,8 @@ class JsonSuite extends QueryTest with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = read.json(complexFieldAndType1) + val compTable = read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1073,29 +1074,29 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ for (useStreaming <- List("true", "false")) { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val df = read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) + assert(read.parquet(temp).count() == 5) - val df2 = ctx.read.json(corruptRecords) + val df2 = read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) + checkAnswer(read.parquet(temp), df2.collect()) } } finally { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index b6a6a8dc6a63c..47a97a49daabb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.json -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.TestSQLContext -trait TestJsonData { +object TestJsonData { - protected def ctx: SQLContext - - def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + val primitiveFieldAndType = + TestSQLContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,8 +32,8 @@ trait TestJsonData { "null":null }""" :: Nil) - def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + val primitiveFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( """{"num_num_1":11, 
"num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,15 +43,15 @@ trait TestJsonData { """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) - def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + val jsonNullStruct = + TestSQLContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + val complexFieldValueTypeConflict = + TestSQLContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,23 +61,23 @@ trait TestJsonData { """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + val arrayElementTypeConflict = + TestSQLContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) - def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + val missingFields = + TestSQLContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) - def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + val complexFieldAndType1 = + TestSQLContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,8 +92,8 @@ trait TestJsonData { "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) - def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + val complexFieldAndType2 = + TestSQLContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,16 +146,16 @@ trait TestJsonData { ]] }""" :: Nil) - def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + val mapType1 = + TestSQLContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: """{"map": {"e": null}}""" :: Nil) - def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + val mapType2 = + TestSQLContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,22 +163,22 @@ trait TestJsonData { """{"map": {"e": null}}""" :: """{"map": {"f": {"field1": null}}}""" :: Nil) - def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + val nullsInArrays = + TestSQLContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: 
"""{"field4":[[null, [1,2,3]]]}""" :: Nil) - def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + val jsonArray = + TestSQLContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) - def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + val corruptRecords = + TestSQLContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,5 +186,6 @@ trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + val empty = + TestSQLContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 17f5f9a491e6b..bdc2ebabc5e9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll -import org.apache.parquet.filter2.predicate.Operators._ -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import parquet.filter2.predicate.Operators._ +import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} @@ -41,7 +42,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * data type is nullable. 
*/ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + val sqlContext = TestSQLContext private def checkFilterPredicate( df: DataFrame, @@ -311,7 +312,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -340,7 +341,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 2b6a27032e637..dd48bb350f26d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -24,18 +24,21 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.scalatest.BeforeAndAfterAll -import org.apache.parquet.example.data.simple.SimpleGroup -import org.apache.parquet.example.data.{Group, GroupWriter} -import org.apache.parquet.hadoop.api.WriteSupport -import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} -import org.apache.parquet.io.api.RecordConsumer -import org.apache.parquet.schema.{MessageType, MessageTypeParser} +import parquet.example.data.simple.SimpleGroup +import parquet.example.data.{Group, GroupWriter} +import parquet.hadoop.api.WriteSupport +import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} +import parquet.io.api.RecordConsumer +import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} @@ -63,8 +66,9 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ + val sqlContext = TestSQLContext + + import sqlContext.implicits.localSeqToDataFrameHolder /** * Writes `data` to a Parquet file, reads it back and check file contents. 
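[Note on the ParquetIOSuiteBase hunks that follow: the suite's basic contract is to write a DataFrame to a temporary Parquet location, read it back through the imported `read`, and compare with checkAnswer. The sketch below is hypothetical and not part of this patch; the class name and column names are invented, while `withTempPath`, `read.parquet`, `toDF`, `Row.fromTuple`, and `checkAnswer` are taken from the hunks themselves, assuming the 1.4-era TestSQLContext-based test helpers.]

```scala
package org.apache.spark.sql.parquet

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Hypothetical example (illustration only) of the write/read-back/check contract
// that ParquetIOSuiteBase exercises.
class ParquetRoundTripSketch extends QueryTest with ParquetTest {
  val sqlContext = TestSQLContext

  test("parquet write/read round trip preserves rows") {
    val data = (1 to 10).map(i => (i, i.toString))
    withTempPath { dir =>
      // Write to a temp path, reload through the imported `read`, compare rows.
      data.toDF("a", "b").write.parquet(dir.getCanonicalPath)
      checkAnswer(read.parquet(dir.getCanonicalPath), data.map(Row.fromTuple))
    }
  }
}
```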
@@ -100,7 +104,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -111,7 +115,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -119,7 +123,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() + read.parquet(dir.getCanonicalPath).collect() } } @@ -127,14 +131,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() + read.parquet(dir.getCanonicalPath).collect() } } } test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateUtils.toJavaDate(i))) .toDF() @@ -143,7 +147,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDateRDD() data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -232,7 +236,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { withParquetFile(data) { path => - assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { + assertResult(conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) } } @@ -240,7 +244,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -279,7 +283,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => + checkAnswer(read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -308,7 +312,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) - checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) + checkAnswer(read.parquet(file), newData.map(Row.fromTuple)) } } @@ -317,7 +321,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) 
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) - checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) + checkAnswer(read.parquet(file), data.map(Row.fromTuple)) } } @@ -337,7 +341,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) - checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) + checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -365,11 +369,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, + sparkContext.hadoopConfiguration, path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(sqlContext.read.parquet(path.toString).schema) { + assertResult(read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -396,13 +400,13 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } finally { configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.parquet.hadoop.ParquetOutputCommitter") + "parquet.hadoop.ParquetOutputCommitter") } } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -426,7 +430,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA } class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 8979a0a210a42..3b29979452ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
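The overwrite / ignore / append hunks above all write to a path that already contains data, which is exactly where SaveMode matters. A hedged sketch of the three behaviours (the output path is illustrative):

    import org.apache.spark.sql.SaveMode
    import org.apache.spark.sql.test.TestSQLContext.implicits._

    val df = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")

    df.write.format("parquet").mode(SaveMode.Overwrite).save("/tmp/t") // replaces whatever is at /tmp/t
    df.write.format("parquet").mode(SaveMode.Ignore).save("/tmp/t")    // no-op: the existing data is kept untouched
    df.write.format("parquet").mode(SaveMode.Append).save("/tmp/t")    // keeps the old rows and adds the new ones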
*/ - package org.apache.spark.sql.parquet import java.io.File @@ -29,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} @@ -39,10 +39,10 @@ case class ParquetData(intField: Int, stringField: String) case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { + override val sqlContext: SQLContext = TestSQLContext - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext._ import sqlContext.implicits._ - import sqlContext.sql val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -190,7 +190,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { // Introduce _temporary dir to the base dir the robustness of the schema discovery process. new File(base.getCanonicalPath, "_temporary").mkdir() - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + println("load the partitioned table") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -237,7 +238,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -285,7 +286,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -325,7 +326,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -357,7 +358,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") + read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -370,7 +371,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: ParquetRelation2) => 
assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -384,7 +385,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { withTempPath { dir => val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + checkAnswer(read.parquet(dir.getCanonicalPath), df.collect()) } } @@ -424,12 +425,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + checkAnswer(read.load(dir.toString).select(fields: _*), row) } } @@ -445,7 +446,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index de0107a361815..304936fb2be8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -22,14 +22,14 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ /** * A test suite that tests various Parquet queries. 
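Partition discovery, which ParquetPartitionDiscoverySuite exercises above, infers partition columns and their types from Hive-style name=value directories under the base path. A hedged sketch of the flow (paths, column names, and values are illustrative):

    import org.apache.spark.sql.test.TestSQLContext._

    // Layout as produced by df.write.partitionBy("pi", "ps").parquet("/tmp/base"):
    //   /tmp/base/pi=1/ps=foo/part-r-00000.parquet
    //   /tmp/base/pi=2/ps=bar/part-r-00000.parquet

    // Reading the base path adds pi (integer) and ps (string) as discovered columns.
    read.parquet("/tmp/base").registerTempTable("t")
    sql("SELECT intField, stringField, pi, ps FROM t WHERE pi = 1").show()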
*/ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + val sqlContext = TestSQLContext test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -40,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -118,7 +118,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) val df2 = sqlContext.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) @@ -127,7 +127,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -139,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd } class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi + val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 171a656f0e01e..caec2a6f25489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.parquet.schema.MessageTypeParser +import parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection +import 
org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index eb15a1609f1d0..516ba373f41d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ private[sql] trait ParquetTest extends SQLTestUtils { + import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} + import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -42,7 +44,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -73,7 +75,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ac4a00a6f3dac..17a8b0cca09df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -25,9 +25,11 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { - def sqlContext: SQLContext + val sqlContext: SQLContext - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + import sqlContext.{conf, sparkContext} + + protected def configuration = sparkContext.hadoopConfiguration /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -37,12 +39,12 @@ trait SQLTestUtils { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConf) + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConf(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => conf.setConf(key, value) + case (key, None) => conf.unsetConf(key) } } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 73e6ccdb1eaf8..20d3c7d4c5959 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 
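The withSQLConf helper above sets the given keys only for the duration of the block and then restores each key to its previous value, or unsets it if it had none, even if the body throws. A hedged usage sketch, assuming it runs inside a suite in the sql test packages that mixes in SQLTestUtils:

    import org.apache.spark.sql.SQLConf

    // Temporarily force the old (non data source) Parquet code path; the previous
    // setting is restored in the finally block of withSQLConf.
    withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") {
      assert(!sqlContext.conf.parquetUseDataSourceApi)
    }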
1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 77272aecf2835..48ac9062af96a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,12 +27,8 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import scala.collection.JavaConversions._ - -private[hive] class SparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) - extends Driver - with Logging { +private[hive] abstract class AbstractSparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { private[hive] var tableSchema: Schema = _ private[hive] var hiveResponse: Seq[String] = _ @@ -75,16 +71,6 @@ private[hive] class SparkSQLDriver( 0 } - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } - override def getSchema: Schema = tableSchema override def destroy() { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index c9da25253e13f..94687eeda4179 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -29,13 +29,12 @@ import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} - /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. 
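As the entry-point comment above says, HiveThriftServer2 can also be embedded in an existing application through the @DeveloperApi startWithContext method shown in the next hunk, which records the Hive compatibility version under spark.sql.hive.version before starting the server. A hedged sketch of that usage (app name and deployment details are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

    val sc = new SparkContext(new SparkConf().setAppName("embedded-thriftserver"))
    val hiveContext = new HiveContext(sc)

    // Sets spark.sql.hive.version (HiveShim.version after this change),
    // initializes the server from hiveconf, and starts listening.
    HiveThriftServer2.startWithContext(hiveContext)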
@@ -52,7 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + sqlContext.setConf("spark.sql.hive.version", HiveShim.version) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 039cfa40d26b3..14f6f658d9b75 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,12 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) + val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 41f647d5f8c5a..499e077d7294a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,6 +21,8 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConversions._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -32,8 +34,7 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ - -import scala.collection.JavaConversions._ +import org.apache.spark.util.Utils private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -51,7 +52,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - setSuperField(this, "serviceUGI", sparkServiceUGI) + HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) } catch { case 
e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123bf..7c0c505e2d61e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + hiveContext.setConf("spark.sql.hive.version", HiveShim.version) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala deleted file mode 100644 index 2d5ee68002286..0000000000000 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
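Both SparkSQLEnv above and startWithContext pin spark.sql.hive.version to HiveShim.version, so a client can check which Hive compatibility version the server was built against with a plain SET statement, as the thrift server tests further down do over JDBC. A hedged standalone sketch (URL and credentials are illustrative; assumes the HiveServer2 JDBC driver is on the classpath):

    import java.sql.DriverManager

    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
    val rs = conn.createStatement().executeQuery("SET spark.sql.hive.version")
    rs.next()
    println(rs.getString(1)) // e.g. "spark.sql.hive.version=0.13.1"
    conn.close()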
- */ - -package org.apache.spark.sql.hive.thriftserver - -import java.util.concurrent.Executors - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.SessionHandle -import org.apache.hive.service.cli.session.SessionManager -import org.apache.hive.service.cli.thrift.TProtocolVersion - -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) - val session = super.getSession(sessionHandle) - HiveThriftServer2.listener.onSessionCreated( - session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index c8031ed0f3437..9c0bf02391e0e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -44,12 +44,9 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val runInBackground = async && hiveContext.hiveThriftServerAsync - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, - runInBackground)(hiveContext, sessionToActivePool) + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( + hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) - logDebug(s"Created Operation for $statement with session=$parentSession, " + - s"runInBackground=$runInBackground") operation } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 13b0c5951dddc..3732af7870b93 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -133,7 +133,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with -e") { - runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -165,7 +165,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( + runCliWithin(1.minute, Seq("--jars", s"$jarFile"))( """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 178bd1f5cb164..a93a3dee43511 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets -import java.sql.{Date, DriverManager, SQLException, Statement} +import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise, future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} @@ -42,7 +40,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils object TestData { @@ -113,8 +111,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") } } @@ -340,42 +337,6 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } ) } - - test("test jdbc cancel") { - withJdbcStatement { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - val largeJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(10)("join test_map").mkString(" ") - val f = future { Thread.sleep(100); statement.cancel(); } - val e = intercept[SQLException] { - statement.executeQuery(largeJoin) - } - assert(e.getMessage contains "cancelled") - Await.result(f, 3.minute) - - // cancel is a noop - statement.executeQuery("SET 
spark.sql.hive.thriftServer.async=false") - val sf = future { Thread.sleep(100); statement.cancel(); } - val smallJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(4)("join test_map").mkString(" ") - val rs1 = statement.executeQuery(smallJoin) - Await.result(sf, 3.minute) - rs1.next() - assert(rs1.getInt(1) === math.pow(5, 5)) - rs1.close() - - val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") - rs2.next() - assert(rs2.getInt(1) === 5) - rs2.close() - } - } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -404,8 +365,7 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala similarity index 52% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala rename to sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index e071103df925c..b9d4f1c58c982 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -17,55 +17,78 @@ package org.apache.spark.sql.hive.thriftserver -import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} -import java.util.concurrent.RejectedExecutionException -import java.util.{Map => JMap, UUID} +import java.util.concurrent.Executors +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} -import scala.util.control.NonFatal -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hive.service.cli._ -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.HiveSession +import org.apache.hive.service.cli.session.{SessionManager, HiveSession} -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +/** + * A compatibility 
layer for interacting with Hive version 0.13.1. + */ +private[thriftserver] object HiveThriftServerShim { + val version = "0.13.1" + + def setServerUserName( + sparkServiceUGI: UserGroupInformation, + sparkCliService:SparkSQLCLIService) = { + setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) + } +} + +private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) + extends AbstractSparkSQLDriver(_context) { + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } +} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], - runInBackground: Boolean = true) - (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) - with Logging { + runInBackground: Boolean = true)( + hiveContext: HiveContext, + sessionToActivePool: SMap[SessionHandle, String]) + // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution + extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ - private var statementId: String = _ def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - hiveContext.sparkContext.clearJobGroup() - logDebug(s"CLOSING $statementId") - cleanup(OperationState.CLOSED) + logDebug("CLOSING") } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -126,10 +149,10 @@ private[hive] class SparkExecuteStatementOperation( } def getResultSetSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } @@ -137,73 +160,9 @@ private[hive] class SparkExecuteStatementOperation( } } - override def run(): Unit = { - setState(OperationState.PENDING) - setHasResultSet(true) // avoid no resultset for async run - - if (!runInBackground) { - runInternal() - } else { - val parentSessionState = SessionState.get() - val hiveConf = getConfigForOperation() - val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - val sessionHive = getCurrentHive() - val currentSqlSession = hiveContext.currentSession - - // Runnable impl to call runInternal asynchronously, - // from a different thread - val backgroundOperation = new Runnable() { - - override def run(): Unit = { - val doAsAction = new PrivilegedExceptionAction[Object]() { - override def run(): Object = { - - // User information is part of the metastore client member in Hive - hiveContext.setSession(currentSqlSession) - Hive.set(sessionHive) - SessionState.setCurrentSessionState(parentSessionState) - try { - runInternal() - } catch { - case e: 
HiveSQLException => - setOperationException(e) - log.error("Error running hive query: ", e) - } - return null - } - } - - try { - ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) - } catch { - case e: Exception => - setOperationException(new HiveSQLException(e)) - logError("Error running hive query as user : " + - sparkServiceUGI.getShortUserName(), e) - } - } - } - try { - // This submit blocks if no background threads are available to run this operation - val backgroundHandle = - getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) - setBackgroundHandle(backgroundHandle) - } catch { - case rejected: RejectedExecutionException => - setState(OperationState.ERROR) - throw new HiveSQLException("The background threadpool cannot accept" + - " new task for execution, please retry the operation", rejected) - case NonFatal(e) => - logError(s"Error executing query in background", e) - setState(OperationState.ERROR) - throw e - } - } - } - - private def runInternal(): Unit = { - statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement' with $statementId") + def run(): Unit = { + val statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) HiveThriftServer2.listener.onStatementStart( statementId, @@ -235,82 +194,63 @@ private[hive] class SparkExecuteStatementOperation( } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) } catch { - case e: HiveSQLException => - if (getStatus().getState() == OperationState.CANCELED) { - return - } else { - setState(OperationState.ERROR); - throw e - } // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - val currentState = getStatus().getState() - logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, e.getStackTraceString) + logError("Error executing query:", e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) HiveThriftServer2.listener.onStatementFinish(statementId) } +} - override def cancel(): Unit = { - logInfo(s"Cancel '$statement' with $statementId") - if (statementId != null) { - hiveContext.sparkContext.cancelJobGroup(statementId) - } - cleanup(OperationState.CANCELED) - } +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { - private def cleanup(state: OperationState) { - setState(state) - if (runInBackground) { - val backgroundHandle = getBackgroundHandle() - if (backgroundHandle != null) { - backgroundHandle.cancel(true) - } - } - } + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - /** - * If there are query specific settings to overlay, then create a copy of config - * There are two cases we need to clone the session config that's being passed to hive driver - * 1. Async query - - * If the client changes a config setting, that shouldn't reflect in the execution - * already underway - * 2. 
confOverlay - - * The query specific settings should only be applied to the query config and not session - * @return new configuration - * @throws HiveSQLException - */ - private def getConfigForOperation(): HiveConf = { - var sqlOperationConf = getParentSession().getHiveConf() - if (!getConfOverlay().isEmpty() || runInBackground) { - // clone the partent session config for this query - sqlOperationConf = new HiveConf(sqlOperationConf) + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) - // apply overlay query specific settings, if any - getConfOverlay().foreach { case (k, v) => - try { - sqlOperationConf.verifyAndSet(k, v) - } catch { - case e: IllegalArgumentException => - throw new HiveSQLException("Error applying statement specific settings", e) - } - } - } - return sqlOperationConf + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) } - private def getCurrentHive(): Hive = { - try { - return Hive.get() - } catch { - case e: HiveException => - throw new HiveSQLException("Failed to get current Hive object", e); - } + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 048f78b4daa8d..0b1917a392901 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,6 +23,7 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -253,7 +254,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // the answer is sensitive for jdk version "udf_java_method" - ) + ) ++ HiveShim.compatibilityBlackList /** * The set of tests that are believed to be working in catalyst. 
Tests not on whiteList or diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a17546d706248..923ffabb9b99e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b8f294c262af7..fbf2c7d8cbc06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,34 +17,37 @@ package org.apache.spark.sql.hive -import java.io.File +import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.net.{URL, URLClassLoader} import java.sql.Timestamp +import java.util.{ArrayList => JArrayList} -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.DataSourceStrategy +import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -144,12 +147,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.metastore.barrierPrefixes", "") .split(",").filterNot(_ == "") - /* - * hive thrift server use background spark sql thread pool to execute sql queries - */ - protected[hive] def hiveThriftServerAsync: Boolean = - getConf("spark.sql.hive.thriftServer.async", "true").toBoolean - @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -334,7 +331,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) + Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -345,7 +342,7 @@ class HiveContext(sc: SparkContext) extends 
SQLContext(sc) { catalog.client.alterTable( relation.table.copy( properties = relation.table.properties + - (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) + (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString))) } case otherRelation => throw new UnsupportedOperationException( @@ -567,7 +564,7 @@ private[hive] object HiveContext { case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveDecimal.create(decimal).toString + HiveShim.createDecimal(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c466203cd0220..24cd335082639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} -import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -351,7 +350,7 @@ private[hive] trait HiveInspectors { new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) @@ -440,31 +439,31 @@ private[hive] trait HiveInspectors { case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. 
- case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) + case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) + case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - getDecimalWritable(a.asInstanceOf[Decimal]) + HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) + HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) + case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) + case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] } case x: SettableStructObjectInspector => @@ -575,31 +574,31 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - getStringWritableConstantObjectInspector(value) + HiveShim.getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - getIntWritableConstantObjectInspector(value) + HiveShim.getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - getDoubleWritableConstantObjectInspector(value) + HiveShim.getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - getBooleanWritableConstantObjectInspector(value) + 
HiveShim.getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - getLongWritableConstantObjectInspector(value) + HiveShim.getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - getFloatWritableConstantObjectInspector(value) + HiveShim.getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - getShortWritableConstantObjectInspector(value) + HiveShim.getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - getByteWritableConstantObjectInspector(value) + HiveShim.getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - getBinaryWritableConstantObjectInspector(value) + HiveShim.getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - getDateWritableConstantObjectInspector(value) + HiveShim.getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - getTimestampWritableConstantObjectInspector(value) + HiveShim.getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - getDecimalWritableConstantObjectInspector(value) + HiveShim.getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - getPrimitiveNullWritableConstantObjectInspector + HiveShim.getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { @@ -659,8 +658,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -669,136 +668,10 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } - private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - 
private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - private def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - private def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - } - - private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - private def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - private def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - private def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - private def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - private def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - private def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new 
hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - implicit class typeInfoConversions(dt: DataType) { import org.apache.hadoop.hive.serde2.typeinfo._ import TypeInfoFactory._ - private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo( - HiveShim.UNLIMITED_DECIMAL_PRECISION, - HiveShim.UNLIMITED_DECIMAL_SCALE) - } - def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) @@ -817,7 +690,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => decimalTypeInfo(d) + case d: DecimalType => HiveShim.decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5a4651a887b7c..ca1f49b546bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} @@ -39,6 +37,7 @@ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} +import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -671,8 +670,8 @@ private[hive] case class MetastoreRelation @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) - val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) + val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) + val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. 
An @@ -698,7 +697,11 @@ private[hive] case class MetastoreRelation } } - val tableDesc = new TableDesc( + val tableDesc = HiveShim.getTableDesc( + Class.forName( + hiveQlTable.getSerializationLib, + true, + Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -740,11 +743,6 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" - } - def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => @@ -761,7 +759,7 @@ private[hive] object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => decimalMetastoreString(d) + case d: DecimalType => HiveShim.decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 9544d12c9053c..a5ca3613c5e00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive import java.sql.Date +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} @@ -37,7 +39,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -45,7 +46,6 @@ import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer /** * Used when we need to start parsing the AST before deciding that we are going to pass the command diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala deleted file mode 100644 index fa5409f602444..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.io.{InputStream, OutputStream} -import java.rmi.server.UID - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} -import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector -import org.apache.hadoop.io.Writable - -import org.apache.spark.Logging -import org.apache.spark.sql.types.Decimal -import org.apache.spark.util.Utils - -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.reflect.ClassTag - -private[hive] object HiveShim { - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - val UNLIMITED_DECIMAL_PRECISION = 38 - val UNLIMITED_DECIMAL_SCALE = 18 - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - /** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. 
- * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ - private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { - - // for Serialization - def this() = this(null) - - @transient - def deserializeObjectByKryo[T: ClassTag]( - kryo: Kryo, - in: InputStream, - clazz: Class[_]): T = { - val inp = new Input(in) - val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] - inp.close() - t - } - - @transient - def serializeObjectByKryo( - kryo: Kryo, - plan: Object, - out: OutputStream) { - val output: Output = new Output(out) - kryo.writeObject(output, plan) - output.close() - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = Utils.getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } - } - - /* - * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - /* - * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. 
- */ - private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName(): String = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 334bfccc9d200..294fc3bd7d5e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,13 +25,14 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils @@ -171,7 +172,7 @@ class HadoopTableReader( path.toString + tails } - val partPath = partition.getDataLocation + val partPath = HiveShim.getDataLocationPath(partition) val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); var pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { @@ -186,7 +187,7 @@ class HadoopTableReader( val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = partition.getDataLocation + val partPath = HiveShim.getDataLocationPath(partition) val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -324,7 +325,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] } else { - ObjectInspectorConverters.getConvertedOI( + HiveShim.getConvertedOI( rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index eeb472602be3c..8613332186f28 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,25 +19,27 @@ package org.apache.spark.sql.hive.execution import java.util +import scala.collection.JavaConversions._ + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} -import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.{SerializableWritable, SparkException, TaskContext} -import scala.collection.JavaConversions._ - private[hive] case class InsertIntoHiveTable( table: MetastoreRelation, @@ -124,7 +126,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) + val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 01f47352b2313..1658bb93b0b79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ /* Implicit conversions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ee440e304ec19..2bb526b14be34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -35,7 +35,8 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} +import 
org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index af586712e3235..58e2d1fbfa73e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -561,28 +561,30 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } } - test("scan a parquet table created through a CTAS statement") { - withSQLConf( - "spark.sql.hive.convertMetastoreParquet" -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - - withTempTable("jt") { - (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") - - withTable("test_parquet_ctas") { - sql( - """CREATE TABLE test_parquet_ctas STORED AS PARQUET - |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 - """.stripMargin) - - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + if (HiveShim.version == "0.13.1") { + test("scan a parquet table created through a CTAS statement") { + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have been converted to ${classOf[ParquetRelation2]}") + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e16e530555aee..00a69de9e4262 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -79,6 +79,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() + // TODO: How does this work? Needs to be added back for other Hive versions.
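Aside, an illustration rather than lines from the patch: the HiveShim.version guards in these suites (including the one that follows) work because ScalaTest registers a test at the moment the test(...) call runs in the suite body, so wrapping that call in a plain if registers it only for matching builds. A self-contained sketch of the idiom, with a literal version string standing in for HiveShim.version:

import org.scalatest.FunSuite

class VersionGatedSuite extends FunSuite {
  private val shimVersion = "0.13.1" // stand-in for HiveShim.version

  // This test is registered only when the guard holds; other builds never see it.
  if (shimVersion == "0.13.1") {
    test("only meaningful when built against Hive 0.13.1") {
      assert(1 + 1 === 2)
    }
  }
}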
+ if (HiveShim.version =="0.12.0") { + assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) + } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6d8d99ebc8164..440b7c87b0da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -874,6 +874,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } + // Now only verify 0.12.0, and ignore other versions due to binary compatibility + // current TestSerDe.jar is from 0.12.0 + if (HiveShim.version == "0.12.0") { + sql(s"ADD JAR $testJar") + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9') + """.stripMargin) + } sql("DROP TABLE alter1") } @@ -881,13 +890,15 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // this is a test case from mapjoin_addjar.q val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath - sql(s"ADD JAR $testJar") - sql( - """CREATE TABLE t1(a string, b string) - |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") + if (HiveShim.version == "0.13.1") { + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } } test("ADD FILE command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 40a35674e4cb8..aba3becb1bce2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -330,33 +330,35 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - 
"org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + if (HiveShim.version =="0.13.1") { + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b384fb39f3d66..57c23fe77f8b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -52,6 +52,9 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + override val sqlContext = TestHive + + import TestHive.read def getTempFilePath(prefix: String, suffix: String = ""): File = { val tempFile = File.createTempFile(prefix, suffix) @@ -66,7 +69,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.format("orc").load(file), + read.format("orc").load(file), data.toDF().collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 5daf691aa8c53..750f0b04aaa87 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,11 +22,13 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql._ private[sql] trait OrcTest extends SQLTestUtils { - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive + protected def hiveContext = 
sqlContext.asInstanceOf[HiveContext] import sqlContext.sparkContext import sqlContext.implicits._ @@ -51,7 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) } /** @@ -63,7 +65,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + hiveContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala new file mode 100644 index 0000000000000..dbc5e029e2047 --- /dev/null +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} +import java.io.{OutputStream, InputStream} + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.{io => hadoopIo} + +import org.apache.spark.Logging +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} +import org.apache.spark.util.Utils._ + +/** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. 
+ * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ +private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream ) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } +} + +/** + * A compatibility layer for interacting with Hive version 0.13.1. 
+ */ +private[hive] object HiveShim { + val version = "0.13.1" + + def getTableDesc( + serdeClass: Class[_ <: Deserializer], + inputFormatClass: Class[_ <: InputFormat[_, _]], + outputFormatClass: Class[_], + properties: Properties) = { + new TableDesc(inputFormatClass, outputFormatClass, properties) + } + + + def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) + + def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + + def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) 
+ } + + def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + } + + def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + + def getPrimitiveNullWritable: NullWritable = NullWritable.get() + + def createDriverResultsArray = new JArrayList[Object] + + def processResults(results: JArrayList[Object]) = { + results.map { r => + r match { + case s: String => s + case a: Array[Object] => a(0).asInstanceOf[String] + } + } + } + + def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE + + def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE + + def createDefaultDBIfNeeded(context: HiveContext) = { + context.runSqlHive("CREATE DATABASE default") + context.runSqlHive("USE default") + } + + def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { + CommandProcessorFactory.get(cmd, conf) + } + + def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { + HiveDecimal.create(bd) + } + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.size > 0) { + ColumnProjectionUtils.appendReadColumns(conf, ids) + } + if (names != null && names.size > 0) { + appendReadColumnNames(conf, names) + } + } + + def getExternalTmpPath(context: Context, path: Path) = { + context.getExternalTmpPath(path.toUri) + } + + def getDataLocationPath(p: Partition) = p.getDataLocation + + def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) + + def compatibilityBlackList = Seq() + + def setLocation(tbl: Table, crtTbl: 
CreateTableDesc): Unit = { + tbl.setDataLocation(new Path(crtTbl.getLocation())) + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + * */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + private val UNLIMITED_DECIMAL_PRECISION = 38 + private val UNLIMITED_DECIMAL_SCALE = 18 + + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" + } + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) + } + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { + if (crtTbl != null && crtTbl.getNullFormat() != null) { + tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) + } + } +} + +/* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper.
+ */ +private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } +} diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5b..49d035a1e9696 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 207d64d9414ee..dd0af251b46bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId, + numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -64,11 +68,17 @@ private[streaming] class BlockManagerBasedBlockHandler( extends ReceivedBlockHandler with Logging { def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] + val countIterator = block match { + case ArrayBufferBlock(arrayBuffer) => new CountingIterator(arrayBuffer.iterator) + case IteratorBlock(iterator) => new CountingIterator(iterator) + case _ => null + } val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + blockManager.putIterator(blockId, countIterator, storageLevel, tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + blockManager.putIterator(blockId, countIterator, storageLevel, tellMaster = true) case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +89,10 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( 
s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + if(countIterator != null) { + numRecords = Some(countIterator.count) + } + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, + numRecords: Option[Long], walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -151,12 +165,18 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] + val countIterator = block match { + case ArrayBufferBlock(arrayBuffer) => new CountingIterator(arrayBuffer.iterator) + case IteratorBlock(iterator) => new CountingIterator(iterator) + case _ => null + } // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.dataSerialize(blockId, arrayBuffer.iterator) + blockManager.dataSerialize(blockId, countIterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + blockManager.dataSerialize(blockId, countIterator) case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -181,7 +201,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, walRecordHandle) + if(countIterator != null) { + numRecords = Some(countIterator.count) + } + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { @@ -199,3 +222,16 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the count + */ +private class CountingIterator[T: Manifest](iterator: Iterator[T]) extends Iterator[T] { + var count = 0 + def hasNext(): Boolean = iterator.hasNext + def isFullyConsumed: Boolean = !iterator.hasNext + def next(): T = { + count+=1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8be732b64e3a3..6078cdf8f8790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) - case _ => None - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - + val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) 
trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index cca8cedb1d080..790d8ab06d3b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,21 @@ class ReceivedBlockHandlerSuite var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var handler: ReceivedBlockHandler = null var tempDirectory: File = null + var storageLevel = StorageLevel.MEMORY_ONLY_SER + + private def makeBlockManager( + maxMem: Long, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + conf.set("spark.storage.unrollMemoryThreshold", "512") + conf.set("spark.storage.unrollFraction", "0.4") + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + manager + } before { rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) @@ -174,6 +187,172 @@ class ReceivedBlockHandlerSuite } } + test("BlockManagerBasedBlockHandler-MEMORY_ONLY-ByteBufferBlock - count messages") { + storageLevel = StorageLevel.MEMORY_ONLY + // Create a non-trivial (not all zeros) byte array + var counter = 0.toByte + def incr: Byte = {counter = (counter + 1).toByte; counter;} + val bytes = Array.fill[Byte](100)(incr) + val byteBufferBlock = ByteBuffer.wrap(bytes) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ByteBufferBlock(byteBufferBlock)) + assert(blockStoreResult.numRecords === None) + } + } + + test("BlockManagerBasedBlockHandler-MEMORY_ONLY-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.MEMORY_ONLY + val block = ArrayBuffer.fill(100)(0) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler-DISK_ONLY-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.DISK_ONLY + val block = ArrayBuffer.fill(100)(0) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler-MEMORY_AND_DISK-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.MEMORY_AND_DISK + val block = ArrayBuffer.fill(100)(0) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler-MEMORY_ONLY-IteratorBlock - count messages") { + storageLevel = StorageLevel.MEMORY_ONLY + val block = ArrayBuffer.fill(100)(0) + 
withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler-DISK_ONLY-IteratorBlock - count messages") { + storageLevel = StorageLevel.DISK_ONLY + val block = ArrayBuffer.fill(100)(0) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler-MEMORY_AND_DISK-IteratorBlock - count messages") { + storageLevel = StorageLevel.MEMORY_AND_DISK + val block = ArrayBuffer.fill(100)(0) + withBlockManagerBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-MEMORY_ONLY-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.MEMORY_ONLY + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-DISK_ONLY-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.DISK_ONLY + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-MEMORY_AND_DISK-ArrayBufferBlock - count messages") { + storageLevel = StorageLevel.MEMORY_AND_DISK + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, ArrayBufferBlock(block)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-MEMORY_ONLY-IteratorBlock - count messages") { + storageLevel = StorageLevel.MEMORY_ONLY + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-DISK_ONLY-IteratorBlock - count messages ") { + storageLevel = StorageLevel.DISK_ONLY + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("WriteAheadLogBasedBlockHandler-MEMORY_AND_DISK-IteratorBlock - count messages") { + storageLevel = StorageLevel.MEMORY_AND_DISK + val block = ArrayBuffer.fill(100)(0) + withWriteAheadLogBasedBlockHandler { handler => + val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator)) + assert(blockStoreResult.numRecords === Some(100)) + } + } + + test("BlockManagerBasedBlockHandler - isFullyConsumed-MEMORY_ONLY") { + storageLevel = StorageLevel.MEMORY_ONLY + blockManager = makeBlockManager(12000) + val block = List.fill(70)(new Array[Byte](100)) + // spark.storage.unrollFraction set to 0.4 for BlockManager + // With 12000 * 0.4 = 4800 bytes of free space for unroll , there is not enough space to store + // this block With MEMORY_ONLY StorageLevel. 
BlockManager will not be able to unroll this block
+    // and hence will not tryToPut this block, resulting in a SparkException
+    withBlockManagerBasedBlockHandler { handler =>
+      val thrown = intercept[SparkException] {
+        val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator))
+      }
+      assert(thrown.getMessage ===
+        "Could not store input-1-1000 to block manager with storage level " + storageLevel)
+    }
+  }
+
+  test("BlockManagerBasedBlockHandler - isFullyConsumed-MEMORY_AND_DISK") {
+    storageLevel = StorageLevel.MEMORY_AND_DISK
+    blockManager = makeBlockManager(12000)
+    val block = List.fill(70)(new Array[Byte](100))
+    // spark.storage.unrollFraction set to 0.4 for BlockManager
+    // With 12000 * 0.4 = 4800 bytes of free space for unroll, there is not enough space to store
+    // this block in memory, but the BlockManager will be able to serialize this block to disk
+    // and hence the count returns the correct value.
+    withBlockManagerBasedBlockHandler { handler =>
+      val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator))
+      assert(blockStoreResult.numRecords === Some(70))
+    }
+  }
+
+  test("WriteAheadLogBasedBlockHandler - isFullyConsumed-MEMORY_ONLY") {
+    storageLevel = StorageLevel.MEMORY_ONLY
+    blockManager = makeBlockManager(12000)
+    val block = List.fill(70)(new Array[Byte](100))
+    // spark.storage.unrollFraction set to 0.4 for BlockManager
+    // With 12000 * 0.4 = 4800 bytes of free space for unroll, there is not enough space to store
+    // this block in memory, but the BlockManager will be able to serialize this block to the WAL
+    // and hence the count returns the correct value.
+    withWriteAheadLogBasedBlockHandler { handler =>
+      val blockStoreResult = storeBlock(handler, IteratorBlock(block.iterator))
+      assert(blockStoreResult.numRecords === Some(70))
+    }
+  }
+
  /**
   * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded
   * using the given verification function
@@ -251,9 +430,20 @@ class ReceivedBlockHandlerSuite
    (blockIds, storeResults)
  }

+  /** Store block using a handler */
+  private def storeBlock(
+      handler: ReceivedBlockHandler,
+      block: ReceivedBlock
+    ): ReceivedBlockStoreResult = {
+    val blockId = StreamBlockId(streamId, 1000L)
+    val blockStoreResult = handler.storeBlock(blockId, block)
+    logDebug("Done inserting")
+    blockStoreResult
+  }
  private def getWriteAheadLogFiles(): Seq[String] = {
    getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId))
  }
  private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index be305b5e0dfea..f793a12843b2f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite
  /** Generate blocks infos using random ids */
  def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
    List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None,
-      BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
+      BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L))))
  }

  /** Get all the data written in the given write ahead log file.
*/ diff --git a/tools/pom.xml b/tools/pom.xml index feffde4c857eb..1c6f3e83a1819 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 62c6354f1e203..2fd17267ac427 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 644def7501dc8..e207a46809684 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.4.0-SNAPSHOT ../pom.xml diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 002d7b6eaf498..760e458972d98 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -67,7 +67,6 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ - private val allocatorLock = new Object() // Fields used in client mode. private var rpcEnv: RpcEnv = null @@ -360,9 +359,7 @@ private[spark] class ApplicationMaster( } logDebug(s"Number of pending allocations is $numPendingAllocate. " + s"Sleeping for $sleepInterval.") - allocatorLock.synchronized { - allocatorLock.wait(sleepInterval) - } + Thread.sleep(sleepInterval) } catch { case e: InterruptedException => } @@ -549,15 +546,8 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => - allocatorLock.synchronized { - if (a.requestTotalExecutors(requestedTotal)) { - allocatorLock.notifyAll() - } - } - - case None => - logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => a.requestTotalExecutors(requestedTotal) + case None => logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 940873fbd046c..21193e7c625e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -146,16 +146,11 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. - * - * @return Whether the new requested total is different than the old value. 
*/ - def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { + def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - true - } else { - false } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index bc42e12dfafd7..d8bc2534c1a6a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -326,7 +326,7 @@ private object YarnClusterDriver extends Logging with Matchers { var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet - sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) data should be (Set(1, 2, 3, 4)) result = "success" } finally {
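// The isFullyConsumed tests added to ReceivedBlockHandlerSuite above rest on a simple space
// budget: makeBlockManager(12000) together with spark.storage.unrollFraction = 0.4 leaves
// roughly 12000 * 0.4 = 4800 bytes for unrolling, while the stored block carries 70 records
// of 100 bytes each. A standalone sketch of that arithmetic follows; it is illustrative only,
// taken from the test comments, and ignores Spark's per-object overhead and real memory
// accounting.
object UnrollBudgetSketch {
  def main(args: Array[String]): Unit = {
    val maxMem = 12000L                       // memory given to makeBlockManager in the suite
    val unrollFraction = 0.4                  // spark.storage.unrollFraction set by the tests
    val unrollBudget = (maxMem * unrollFraction).toLong   // 4800 bytes available for unrolling

    val records = 70
    val bytesPerRecord = 100                  // each record is a 100-byte array (payload only)
    val payload = records * bytesPerRecord    // 7000 bytes, already above the 4800-byte budget

    // MEMORY_ONLY: the block cannot be unrolled, the put fails, and the handler surfaces a
    // SparkException. MEMORY_AND_DISK and the write-ahead-log path can still write the
    // serialized block out, so the record count Some(70) is preserved.
    println(s"unroll budget = $unrollBudget bytes, payload >= $payload bytes, " +
      s"fits in memory = ${payload <= unrollBudget}")
  }
}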